Commit b97483a4 authored by Rick Ho

megatron support for specific expert parallelism

parent 8f67b530
@@ -47,7 +47,7 @@ class FMoELinear(nn.Module):
         device = self.weight.device
         dtype = self.weight.dtype
         weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        self.weight.data = torch.Tensor(weight, dtype=dtype, device=device)
         if self.bias is not None:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
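
For anyone adapting this initialization elsewhere: the legacy torch.Tensor constructor does not accept a dtype keyword, so copying the NumPy array produced by rng.uniform onto the parameter is normally done with the torch.tensor (or torch.from_numpy) factory instead. A minimal standalone sketch of the same seeded-uniform pattern; the shape, bound, and seed are illustrative assumptions, not values from the commit:

    import numpy as np
    import torch

    # Stand-in for FMoELinear's self.weight; shape and bound are
    # assumptions for this sketch, not values from the commit.
    weight_param = torch.empty(4, 16, 8, dtype=torch.float32)
    bound = 0.05

    rng = np.random.default_rng(1234)  # seed arbitrary for this sketch
    weight = rng.uniform(-bound, bound, size=tuple(weight_param.size()))

    # torch.tensor (or torch.from_numpy(...).to(...)) accepts dtype and
    # device; the legacy torch.Tensor constructor takes no dtype keyword.
    weight_param.data = torch.tensor(
        weight, dtype=weight_param.dtype, device=weight_param.device
    )
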
@@ -143,7 +143,7 @@ class FMoE(nn.Module):
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         if expert is not None:
             self.experts = nn.ModuleList([expert(d_model)
                                           for _ in range(num_expert)])
             self.experts_fused = False
         else:
...
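
For context, the expert argument seen in this hunk is a callable that builds one expert module from d_model; when it is supplied, FMoE instantiates num_expert independent modules instead of a fused expert. A hypothetical sketch of such a factory, with the MLP structure assumed rather than taken from the repository:

    import torch.nn as nn

    # Hypothetical expert factory: expert(d_model) must return one
    # nn.Module, mirroring the expert argument in the hunk above.
    def make_expert(d_model):
        return nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    # Mirrors the ModuleList construction shown in the diff.
    experts = nn.ModuleList([make_expert(512) for _ in range(4)])
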
@@ -26,7 +26,8 @@ class MegatronMLP(FMoETransformerMLP):
         super().__init__(args.num_experts,
                 top_k=args.top_k,
                 d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
-                world_size=world_size, mp_group=group)
+                world_size=world_size, mp_group=group,
+                expert_dp_comm='none' if args.distributed_experts else 'dp')
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
         )
...
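
The expert_dp_comm argument added here is the point of the commit: with args.distributed_experts set, each data-parallel rank owns a distinct set of experts, so their gradients must be excluded from the data-parallel all-reduce ('none'); without it, the experts are replicated and synchronized like ordinary parameters ('dp'). A sketch of how such a per-parameter tag could be consumed at gradient-sync time; the dp_comm attribute name and the reduction logic are assumptions for illustration, not the library's confirmed internals:

    import torch.distributed as dist

    def allreduce_tagged_gradients(module, dp_group=None):
        # Hypothetical consumer of a per-parameter communication tag.
        # 'dp'  : replicated parameter, average gradients across the group.
        # 'none': rank-local expert parameter, leave its gradient alone.
        for p in module.parameters():
            if p.grad is None:
                continue
            if getattr(p, "dp_comm", "dp") == "dp":
                dist.all_reduce(p.grad, group=dp_group)
                p.grad /= dist.get_world_size(group=dp_group)
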