Commit 82023779 authored by 1SAA's avatar 1SAA Committed by Frank Lee
Browse files

Added TPExpert for special situation

parent 36b84772
from .experts import Experts, FFNExperts
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator
from .utils import NormalNoiseGenerator, build_ffn_experts
__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'build_ffn_experts'
]
......@@ -15,6 +15,55 @@ except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
comm_size = gpc.get_world_size(parallel_mode)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
comm_size = gpc.get_world_size(parallel_mode)
if comm_size == 1:
return inputs.squeeze(0)
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
......
......@@ -7,7 +7,16 @@ from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
class Experts(nn.Module):
class MoeExperts(nn.Module):
def __init__(self, comm: str):
super().__init__()
assert comm in {"all_to_all", "all_gather"}, \
"This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm = comm
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
......@@ -20,19 +29,22 @@ class Experts(nn.Module):
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
super().__init__("all_to_all")
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', True)
self.num_local_experts = num_local_experts
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
......@@ -44,10 +56,10 @@ class Experts(nn.Module):
return output
class FFNExperts(nn.Module):
class FFNExperts(MoeExperts):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__()
super().__init__("all_to_all")
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
......@@ -75,7 +87,7 @@ class FFNExperts(nn.Module):
for param in self.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs): # x [g, el, c, h]
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
......@@ -96,3 +108,58 @@ class FFNExperts(nn.Module):
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather")
assert d_ff % moe_env.model_parallel_size == 0, \
"d_ff should be divied by moe model size"
p_ff = d_ff // moe_env.model_parallel_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.MOE_MODEL):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__('moe_param', True)
self.w2.__setattr__('moe_param', True)
self.b1.__setattr__('moe_param', True)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
......@@ -8,7 +8,8 @@ from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device
from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts
from .utils import autocast_softmax
......@@ -198,7 +199,7 @@ class MoeLayer(nn.Module):
:type experts: nn.Module
"""
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
......@@ -207,8 +208,8 @@ class MoeLayer(nn.Module):
self.experts = experts
self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
def expert_part(self, expert_input: torch.Tensor):
expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
input_shape = expert_input.shape
......@@ -221,24 +222,42 @@ class MoeLayer(nn.Module):
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
return expert_output
def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
return expert_out
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
gate_output = self.gate(tokens)
router_res = self.router(gate_output, self.cuda_mode)
if self.cuda_mode:
logits, mask, dest_idx, ec = router_res
expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
expert_output = self.expert_part(expert_input)
ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
sec_mask_f = router_res[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h]
if self.experts.comm == "all_to_all":
expert_output = self.a2a_process(dispatch_data)
elif self.experts.comm == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
if self.cuda_mode:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *router_res)
else:
combine_weights, sec_mask = router_res
sec_mask_f = sec_mask.type_as(inputs)
expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
expert_output = self.expert_part(expert_input)
combine_weights = router_res[0]
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
ans = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
return ret
ans = ans.reshape(inputs.shape)
return ans
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
from colossalai.global_variables import moe_env
from .experts import FFNExperts, TPExperts
class NormalNoiseGenerator:
......@@ -14,10 +16,9 @@ class NormalNoiseGenerator:
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2,
device=get_current_device())).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
......@@ -30,3 +31,13 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
sm_output = F.softmax(sm_input, dim)
return sm_output
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
moe_mp_size = moe_env.model_parallel_size
if num_experts % moe_mp_size == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % moe_mp_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
else:
raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
......@@ -4,7 +4,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
......@@ -110,7 +110,7 @@ class Widenet(nn.Module):
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
......@@ -177,7 +177,7 @@ class ViTMoE(nn.Module):
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),
......
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
......@@ -9,10 +7,10 @@ import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
from colossalai.context.random import moe_set_seed
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
......@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
moe_set_seed(42)
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size,
dtype=data_type, device=get_current_device(), requires_grad=True)
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
expert = Experts(nn.Identity, 4)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
......@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing, world_size=world_size, port=free_port(),
rs=rs, hidden_size=hidden_size, data_type=data_type)
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment