Commit b97483a4 authored by Rick Ho

megatron support for specific expert parallelism

parent 8f67b530
@@ -47,7 +47,7 @@ class FMoELinear(nn.Module):
         device = self.weight.device
         dtype = self.weight.dtype
         weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        self.weight.data = torch.Tensor(weight, dtype=dtype, device=device)
 
         if self.bias is not None:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
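For readers skimming the hunk above: the initializer draws per-expert weights with a NumPy Generator and copies them into the parameter with the parameter's dtype and device. The following is a minimal, self-contained sketch of that pattern; the shapes, the rank-based seed, and the bound are illustrative assumptions, not taken from this commit.

import math

import numpy as np
import torch
from torch import nn

# Minimal sketch of the per-expert initialization pattern shown above.
# Assumptions (not from the commit): 4 experts, 16x16 weights, a rank-based seed.
rank = 0                                       # hypothetical expert-parallel rank
weight = nn.Parameter(torch.empty(4, 16, 16))  # [num_expert, out_feat, in_feat]

fan_in = weight.shape[-1]
bound = 1.0 / math.sqrt(fan_in)                # uniform bound in the style of nn.Linear

rng = np.random.default_rng(rank)              # per-rank seed so experts differ across workers
values = rng.uniform(-bound, bound, size=tuple(weight.size()))
weight.data = torch.tensor(values, dtype=weight.dtype, device=weight.device)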
@@ -26,7 +26,8 @@ class MegatronMLP(FMoETransformerMLP):
         super().__init__(args.num_experts,
                 top_k=args.top_k,
                 d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
-                world_size=world_size, mp_group=group)
+                world_size=world_size, mp_group=group,
+                expert_dp_comm='none' if args.distributed_experts else 'dp')
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
         )
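The new expert_dp_comm argument distinguishes how expert parameters take part in data-parallel gradient reduction: with args.distributed_experts each worker owns its own experts, so they are tagged 'none' and skipped, whereas replicated experts stay tagged 'dp' and are averaged as usual. Below is a hedged sketch of that idea; the dp_comm attribute name and the helper function are illustrative assumptions, not necessarily FastMoE's actual mechanism.

import torch
import torch.distributed as dist
from torch import nn

def allreduce_tagged_gradients(model: nn.Module, dp_group=None):
    """Average gradients over the data-parallel group, honoring a per-parameter tag.

    Sketch only: assumes torch.distributed is already initialized and that
    parameters carry a `dp_comm` attribute set to 'dp' (replicated, reduce)
    or 'none' (worker-local experts, leave the local gradient untouched).
    """
    world_size = dist.get_world_size(group=dp_group)
    for p in model.parameters():
        if p.grad is None:
            continue
        if getattr(p, "dp_comm", "dp") == "dp":
            dist.all_reduce(p.grad, group=dp_group)  # sum over data-parallel ranks
            p.grad /= world_size                     # turn the sum into an average
        # dp_comm == 'none': distributed experts keep their local gradients.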