Commit b97483a4 authored by Rick Ho

megatron support for specific expert parallelism

parent 8f67b530
@@ -47,7 +47,7 @@ class FMoELinear(nn.Module):
         device = self.weight.device
         dtype = self.weight.dtype
         weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.Tensor(weight, dtype=dtype, device=device)
+        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
         if self.bias is not None:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
...
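This hunk matters because the legacy `torch.Tensor` constructor does not accept `dtype` or `device` keyword arguments, whereas `torch.tensor` copies the NumPy array and honours both. A minimal standalone sketch of the same initialization pattern (the shape and bound are illustrative, not taken from FMoELinear):

```python
import numpy as np
import torch

# Illustrative stand-in for the expert weight parameter; the real shape is
# (num_expert, out_features, in_features) in FMoELinear.
param = torch.empty((4, 16, 32), dtype=torch.float16)

rng = np.random.default_rng(0)
bound = 0.1
weight = rng.uniform(-bound, bound, size=tuple(param.size()))

# torch.Tensor(weight, dtype=..., device=...) would raise a TypeError, since the
# legacy constructor takes no dtype/device keywords. torch.tensor copies the
# NumPy array and preserves the existing parameter's dtype and device.
param.data = torch.tensor(weight, dtype=param.dtype, device=param.device)
```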
@@ -26,7 +26,8 @@ class MegatronMLP(FMoETransformerMLP):
         super().__init__(args.num_experts,
                 top_k=args.top_k,
                 d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
-                world_size=world_size, mp_group=group)
+                world_size=world_size, mp_group=group,
+                expert_dp_comm='none' if args.distributed_experts else 'dp')
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
         )
...
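The added `expert_dp_comm` argument tags the expert parameters for data-parallel communication: when `args.distributed_experts` is set, each rank owns its own experts, so their gradients must not be averaged over the data-parallel group (`'none'`); otherwise they are synchronized like ordinary weights (`'dp'`). Below is a minimal sketch of how such a per-parameter tag could be consumed during gradient synchronization; the `allreduce_gradients` helper and the `dp_comm` attribute name are assumptions for illustration, not fastmoe's actual API:

```python
import torch
import torch.distributed as dist

def allreduce_gradients(model, dp_group):
    """Hypothetical data-parallel gradient sync that respects a per-parameter
    communication tag like the one chosen via `expert_dp_comm` above."""
    world = dist.get_world_size(group=dp_group)
    for param in model.parameters():
        if param.grad is None:
            continue
        # Parameters tagged 'none' are experts owned exclusively by this rank;
        # averaging them across data-parallel ranks would mix unrelated experts.
        if getattr(param, "dp_comm", "dp") == "none":
            continue
        dist.all_reduce(param.grad, group=dp_group)
        param.grad.div_(world)
```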