Unverified commit dceae851 authored by HELSON, committed by GitHub

Added MoE parallel (#127)

parent 42741dd4
@@ -66,7 +66,7 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     :type last_epoch: int, optional
     """
-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
+    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1):
         base_scheduler = _CosineAnnealingLR(
             optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
         super().__init__(optimizer, warmup_steps, base_scheduler)
...
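For context, a minimal usage sketch of the scheduler after this annotation fix: eta_min is the floor learning rate reached at the end of cosine annealing, so a float default matches torch.optim.lr_scheduler.CosineAnnealingLR. The import path below is assumed, not shown in this diff.

import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# warm up for 10 steps, then cosine-anneal the lr toward eta_min over the rest
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=100, warmup_steps=10, eta_min=1e-5)
for _ in range(100):
    optimizer.step()
    scheduler.step()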
@@ -17,6 +17,7 @@
 import torch.distributed as dist
 from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.global_variables import moe_env
 from .multi_tensor_apply import multi_tensor_applier
@@ -91,6 +92,10 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
 
+def is_moe_parallel_parameter(p):
+    return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
+
+
 def _calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
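A brief sketch of how a parameter satisfies the new predicate. The 'moe_param' attribute name comes straight from the check above, but setting it manually here is purely illustrative (in the real code it is presumably set by the MoE modules on their expert parameters):

import torch

p = torch.nn.Parameter(torch.randn(4, 4))
p.moe_param = True  # attribute checked by is_moe_parallel_parameter
# the predicate is True only when moe_env.data_parallel_size > 1,
# i.e. when the MoE data-parallel group spans more than one rank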
@@ -165,26 +170,37 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     else:
         tensor_parallel_grads = []
         no_tensor_parallel_grads = []
+        moe_parallel_grads = []  # used to collect moe tensor parallel gradients
         for p in params:
             if is_model_parallel_parameter(p):
                 reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
+            elif is_moe_parallel_parameter(p):
+                moe_parallel_grads.append(p.grad.data)
             else:
                 no_tensor_parallel_grads.append(p.grad.data)
         if norm_type == 2.0:
             tensor_parallel_norm = _calc_l2_norm(
                 tensor_parallel_grads) ** norm_type
             no_tensor_parallel_norm = _calc_l2_norm(
                 no_tensor_parallel_grads) ** norm_type
+            moe_parallel_norm = _calc_l2_norm(
+                moe_parallel_grads) ** norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
-            no_tensor_parallel_grads = _calc_lp(
+            no_tensor_parallel_norm = _calc_lp(
                 no_tensor_parallel_grads, norm_type)
+            moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
             dist.all_reduce(tensor_parallel_norm,
                             op=dist.ReduceOp.SUM,
                             group=gpc.get_group(ParallelMode.TENSOR))
+        # Sum across all moe-tensor-parallel GPUs
+        if len(moe_parallel_grads) > 0:
+            dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
+        no_tensor_parallel_norm += moe_parallel_norm
         total_norm = tensor_parallel_norm + no_tensor_parallel_norm
     if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
         dist.all_reduce(total_norm,
...
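What the new branch computes, in isolation: each gradient group contributes a sum of |g|^p terms, the MoE sum is all-reduced over the MOE_MODEL group and folded into the non-tensor-parallel term, and the final p-th root is taken later in the function (not shown in this hunk). A self-contained single-process sketch of that arithmetic, with no distributed calls:

import torch

def grouped_grad_norm(tensor_parallel_pow, moe_pow, rest_pow, norm_type=2.0):
    # each *_pow argument is sum(|g| ** norm_type) for one gradient group,
    # after any cross-rank all-reduce; summing the groups and then taking
    # the root is equivalent to one norm over all gradients at once
    return (tensor_parallel_pow + (rest_pow + moe_pow)) ** (1.0 / norm_type)

grads = [torch.ones(4), 2 * torch.ones(4)]
total = grouped_grad_norm(
    tensor_parallel_pow=torch.tensor(0.0),
    moe_pow=sum(g.norm(2) ** 2 for g in grads[:1]),
    rest_pow=sum(g.norm(2) ** 2 for g in grads[1:]),
)
assert torch.isclose(total, torch.cat(grads).norm(2))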
new file (helper module defining TransformerLayer, imported below as ..helper):

import torch
import torch.nn as nn

from colossalai.nn.layer import WrappedDropPath as DropPath


class TransformerLayer(nn.Module):
    """Transformer layer builder."""

    def __init__(self,
                 att: nn.Module,
                 ffn: nn.Module,
                 norm1: nn.Module,
                 norm2: nn.Module,
                 droppath=None,
                 droppath_rate: float = 0):
        super().__init__()
        self.att = att
        self.ffn = ffn
        self.norm1 = norm1
        self.norm2 = norm2
        self.droppath = DropPath(droppath_rate) if droppath is None else droppath

    def forward(self, x):
        # pre-norm residual blocks with stochastic depth on each branch
        x = x + self.droppath(self.att(self.norm1(x)))
        x = x + self.droppath(self.ffn(self.norm2(x)))
        return x
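A minimal smoke test of the layer with stand-in modules; the stand-ins below are illustrative, and any modules mapping (B, N, D) -> (B, N, D) work for att and ffn. Passing an explicit droppath avoids constructing WrappedDropPath, which may need a parallel context:

import torch
import torch.nn as nn

d_model = 64
layer = TransformerLayer(
    att=nn.Linear(d_model, d_model),                     # stand-in attention
    ffn=nn.Sequential(nn.Linear(d_model, 4 * d_model),
                      nn.GELU(),
                      nn.Linear(4 * d_model, d_model)),  # stand-in MLP
    norm1=nn.LayerNorm(d_model),
    norm2=nn.LayerNorm(d_model),
    droppath=nn.Identity(),  # skip stochastic depth in the sketch
)
x = torch.randn(2, 16, d_model)
assert layer(x).shape == x.shape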
new file (Widenet model built from the MoE layers):

import math

import torch
import torch.nn as nn

from colossalai.context import ParallelMode
from colossalai.global_variables import moe_env
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
    WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.utils import get_current_device

from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer

class VanillaSelfAttention(nn.Module):
    """Standard ViT self-attention."""

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 d_kv: int,
                 attention_drop: float = 0,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_kv = d_kv
        self.scale = 1.0 / math.sqrt(self.d_kv)

        self.dense1 = nn.Linear(d_model, 3 * n_heads * d_kv, bias, device=get_current_device())
        self.softmax = nn.Softmax(dim=-1)
        self.atten_drop = nn.Dropout(attention_drop) if dropout1 is None else dropout1
        self.dense2 = nn.Linear(n_heads * d_kv, d_model, device=get_current_device())
        self.dropout = nn.Dropout(drop_rate) if dropout2 is None else dropout2

    def forward(self, x):
        qkv = self.dense1(x)                 # (B, N, 3 * n_heads * d_kv)
        new_shape = qkv.shape[:2] + (3, self.n_heads, self.d_kv)
        qkv = qkv.view(*new_shape)
        qkv = qkv.permute(2, 0, 3, 1, 4)     # (3, B, n_heads, N, d_kv)
        q, k, v = qkv[:]

        x = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        x = self.atten_drop(self.softmax(x))
        x = torch.matmul(x, v)

        x = x.transpose(1, 2)                # (B, N, n_heads, d_kv)
        new_shape = x.shape[:2] + (self.n_heads * self.d_kv,)
        x = x.reshape(*new_shape)
        x = self.dense2(x)
        x = self.dropout(x)
        return x
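A quick single-process shape check of the attention block, assuming get_current_device falls back to CPU when no parallel context has been launched:

import torch

attn = VanillaSelfAttention(d_model=64, n_heads=4, d_kv=16)
out = attn(torch.randn(2, 8, 64))
assert out.shape == (2, 8, 64)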
class VanillaFFN(nn.Module):
    """FFN composed of two linear layers, also called an MLP."""

    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 activation=None,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        dense1 = nn.Linear(d_model, d_ff, bias, device=get_current_device())
        act = nn.GELU() if activation is None else activation
        dense2 = nn.Linear(d_ff, d_model, bias, device=get_current_device())
        drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1
        drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2

        self.ffn = nn.Sequential(dense1, act, drop1, dense2, drop2)

    def forward(self, x):
        return self.ffn(x)
class Widenet(nn.Module):
    def __init__(self,
                 num_experts: int,
                 capacity_factor: float,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 d_model: int = 768,
                 num_heads: int = 12,
                 d_kv: int = 64,
                 d_ff: int = 3072,
                 attention_drop: float = 0.,
                 drop_rate: float = 0.1,
                 drop_path: float = 0.):
        super().__init__()

        embedding = VanillaPatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_size=d_model)
        embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        # WideNet shares one attention block, one router and one set of experts
        # across all transformer layers; only the LayerNorms differ per layer
        shared_sa = VanillaSelfAttention(**moe_sa_args(
            d_model=d_model, n_heads=num_heads, d_kv=d_kv,
            attention_drop=attention_drop, drop_rate=drop_rate))

        noisy_func = NormalNoiseGenerator(num_experts)
        shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
        shared_experts = Experts(expert=VanillaFFN,
                                 num_experts=num_experts,
                                 **moe_mlp_args(d_model=d_model,
                                                d_ff=d_ff,
                                                drop_rate=drop_rate))

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            TransformerLayer(att=shared_sa,
                             ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
                                          router=shared_router, experts=shared_experts),
                             norm1=nn.LayerNorm(d_model, eps=1e-6),
                             norm2=nn.LayerNorm(d_model, eps=1e-6),
                             droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
            for i in range(depth)
        ]
        norm = nn.LayerNorm(d_model, eps=1e-6)

        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        # reset the MoE auxiliary-loss accumulator before each forward pass
        moe_env.reset_loss()
        x = self.widenet(x)
        x = torch.mean(x, dim=1)  # global average pooling over tokens
        x = self.linear(x)
        return x
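A hedged usage sketch: the model needs a launched Colossal-AI context so that moe_env and the MoE process groups exist. The launch call and empty config below are illustrative, not from this commit; a real run also needs a config that enables MoE parallelism.

import colossalai
import torch

colossalai.launch_from_torch(config=dict())  # assumed entry point

model = Widenet(num_experts=4, capacity_factor=1.2)
images = torch.randn(2, 3, 224, 224, device=next(model.parameters()).device)
logits = model(images)  # (2, 1000)
# the router's load-balancing loss accumulated in moe_env is typically
# added to the task loss by the training loop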
new file (util module with argument helpers for the MoE modules above):

from colossalai.context import ParallelMode
from colossalai.nn.layer import WrappedDropout as Dropout


def moe_sa_args(d_model: int,
                n_heads: int,
                d_kv: int,
                attention_drop: float = 0,
                drop_rate: float = 0,
                bias: bool = True):
    """Example arguments for MoE self-attention; many modules must be adapted
    before they can be placed inside experts.
    """
    dropout1 = Dropout(attention_drop, mode=ParallelMode.TENSOR)
    dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
    return dict(d_model=d_model,
                n_heads=n_heads,
                d_kv=d_kv,
                bias=bias,
                dropout1=dropout1,
                dropout2=dropout2)


def moe_mlp_args(d_model: int,
                 d_ff: int,
                 drop_rate: float,
                 bias: bool = True):
    """Example arguments for the MLP placed inside Experts; many modules must
    be adapted before they can be placed inside experts.
    """
    dropout1 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
    dropout2 = Dropout(drop_rate, mode=ParallelMode.TENSOR)
    return dict(d_model=d_model,
                d_ff=d_ff,
                bias=bias,
                dropout1=dropout1,
                dropout2=dropout2)
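The returned dicts are splatted straight into the module constructors, exactly as Widenet does above; a compact restatement, assuming the classes from the model file are importable and a parallel context is launched (WrappedDropout takes a ParallelMode):

sa_kwargs = moe_sa_args(d_model=768, n_heads=12, d_kv=64,
                        attention_drop=0.0, drop_rate=0.1)
attention = VanillaSelfAttention(**sa_kwargs)

mlp_kwargs = moe_mlp_args(d_model=768, d_ff=3072, drop_rate=0.1)
experts = Experts(expert=VanillaFFN, num_experts=4, **mlp_kwargs)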