"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "54229cd33e147b883827158b006e80723a6a03e0"
Unverified commit f7f22487, authored by HELSON, committed by GitHub
Browse files

[moe] fix MoE bugs (#1628)

* remove forced FP32 modules

* correct no_shard-contexts' positions
parent 38c68b5b
...@@ -24,6 +24,7 @@ class MoeExperts(nn.Module): ...@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts): class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the """A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert moe model parallel group, where E is the number of experts. Every expert
...@@ -35,7 +36,6 @@ class Experts(MoeExperts): ...@@ -35,7 +36,6 @@ class Experts(MoeExperts):
expert_args: Args used to initialize experts, the args could be found in corresponding expert class expert_args: Args used to initialize experts, the args could be found in corresponding expert class
""" """
@no_shard_zero_decrator(is_replicated=False)
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts) super().__init__("all_to_all", num_experts)
......
...@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module): ...@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
return F.linear(x, self.weight) return F.linear(x, self.weight)
@no_shard_zero_decrator(is_replicated=True)
class MoeLayer(nn.Module): class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits """A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across to router all tokens, is mainly used to exchange all tokens for every expert across
...@@ -241,12 +242,11 @@ class MoeLayer(nn.Module): ...@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
experts (:class:`torch.nn.Module`): Instance of experts generated by Expert. experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
""" """
@no_shard_zero_decrator(is_replicated=True)
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts): def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
super().__init__() super().__init__()
self.d_model = dim_model self.d_model = dim_model
self.num_experts = num_experts self.num_experts = num_experts
self.gate = FP32LinearGate(dim_model, num_experts) self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
self.router = router self.router = router
self.experts = experts self.experts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
...@@ -254,16 +254,14 @@ class MoeLayer(nn.Module): ...@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
self.ep_size = experts.dist_info.ep_size self.ep_size = experts.dist_info.ep_size
self.num_local_experts = experts.num_local_experts self.num_local_experts = experts.num_local_experts
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
def a2a_process(self, dispatch_data: torch.Tensor): def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, self.ep_group) expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
expert_output = self.experts(expert_input) expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape) expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group) expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output return expert_output
...@@ -274,16 +272,22 @@ class MoeLayer(nn.Module): ...@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
return expert_out return expert_out
def forward(self, inputs: torch.Tensor) -> torch.Tensor: def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# reshape the input tokens
tokens = inputs.reshape(-1, self.d_model) tokens = inputs.reshape(-1, self.d_model)
fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
gate_output = self.gate(fp32_input) # the data type of the inputs in the gating should be fp32
router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
if self.use_kernel: if self.use_kernel:
dispatch_data = MoeDispatch.apply(tokens, *router_res[1:]) dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else: else:
sec_mask_f = router_res[1].type_as(inputs) sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h] # dispatch_data [e, c, h]
...@@ -295,12 +299,11 @@ class MoeLayer(nn.Module): ...@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.") "build function.")
# expert_output [e, c, h] # expert_output [e, c, h]
if self.use_kernel: if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model) expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *router_res) ans = MoeCombine.apply(expert_output, *route_result_list)
else: else:
combine_weights = router_res[0].type_as(inputs) combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1) combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1]) expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output) ans = torch.matmul(combine_weights, expert_output)
......
...@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True): ...@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):
def _no_shard(*args, **kwargs): def _no_shard(*args, **kwargs):
with no_shard_zero_context(is_replicated): with no_shard_zero_context(is_replicated):
init_func(*args, **kwargs) ret = init_func(*args, **kwargs)
return ret
return _no_shard return _no_shard
......
...@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f ...@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
layer = layer.to(get_current_device())
if data_type == torch.float16: if data_type == torch.float16:
layer = layer.half() layer = layer.half()
...@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f ...@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# save all results # save all results
o_tk_grad = tokens.grad.data.clone() o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone() o_gt_grad = layer.gate_weight.grad.data.clone()
# reset all gradients # reset all gradients
tokens.grad.zero_() tokens.grad.zero_()
layer.gate.weight.grad.zero_() layer.gate_weight.grad.zero_()
layer.use_kernel = True layer.use_kernel = True
new_out = layer(tokens) # get ouputs through colossal kernel new_out = layer(tokens) # get ouputs through colossal kernel
...@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f ...@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
new_out.backward(grad) # get new type gradient new_out.backward(grad) # get new type gradient
n_tk_grad = tokens.grad.data.clone() n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone() n_gt_grad = layer.gate_weight.grad.data.clone()
if data_type == torch.float32: if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad) check_equal(o_tk_grad, n_tk_grad)
......
...@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class): ...@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
for name, param in model.named_parameters(): for name, param in model.named_parameters():
assert hasattr(param, 'colo_attr') assert hasattr(param, 'colo_attr')
# the weights in the gate should be fp32
if 'gate' in name:
assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
else:
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
# the parameters in moe experts and its gate should not be sharded # the parameters in moe experts and its gate should not be sharded
if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
assert not param.colo_attr.sharded_data_tensor.is_sharded assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
else: else:
assert param.colo_attr.sharded_data_tensor.is_sharded assert param.colo_attr.sharded_data_tensor.is_sharded
......
...@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload, ...@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
apex_grad_handler = MoeGradientHandler(model) apex_grad_handler = MoeGradientHandler(model)
# Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
if 'gate' in n:
p.data = p.float()
p.data.copy_(zp.colo_attr.data_payload)
for i, (data, label) in enumerate(train_dataloader): for i, (data, label) in enumerate(train_dataloader):
if i > 5: if i > 5:
break break
......
...@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard= ...@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
else: else:
zero_p = zero_p.colo_attr.data_payload.to(p.device) zero_p = zero_p.colo_attr.data_payload.to(p.device)
assert p.dtype == zero_p.dtype assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment