Commit d83234b0 authored by Rick Ho

use parallel label in gate

parent 67c667f2
@@ -29,20 +29,20 @@ class DistributedGroupedDataParallel(nn.Module):
         for p in self.module.parameters():
             if not p.requires_grad or p.grad is None:
                 continue
-            if hasattr(p, 'parallel_method'):
-                pm = p.parallel_method
+            if hasattr(p, 'dp_comm'):
+                dp_comm = p.dp_comm
             else:
-                pm = 'dp'
-            group_key = (pm, p.dtype)
+                dp_comm = 'dp'
+            group_key = (dp_comm, p.dtype)
             if group_key not in groups:
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for pm, dtype in groups:
-            if pm not in self.comms:
+        for dp_comm, dtype in groups:
+            if dp_comm not in self.comms:
                 continue
-            group = groups[pm, dtype]
-            comm = self.comms[pm]
+            group = groups[dp_comm, dtype]
+            comm = self.comms[dp_comm]
             grads = [p.grad.data for p in group]
             coalesced = _flatten_dense_tensors(grads)
             if fp32_allreduce and dtype != torch.float32:
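For readers outside the repository: the hunk above renames the per-parameter attribute from parallel_method to dp_comm and buckets gradients by the pair (label, dtype) before the coalesced all-reduce. Below is a minimal, runnable sketch of that bucketing convention only; the helper name group_params_by_comm and the toy modules are hypothetical and not part of FastMoE.

import torch.nn as nn

def group_params_by_comm(module):
    # Bucket parameters by (dp_comm, dtype); 'dp' is the default label,
    # matching the fallback in the hunk above.
    groups = {}
    for p in module.parameters():
        if not p.requires_grad:
            continue
        dp_comm = getattr(p, 'dp_comm', 'dp')  # label attached via setattr(p, 'dp_comm', ...)
        groups.setdefault((dp_comm, p.dtype), []).append(p)
    return groups

if __name__ == '__main__':
    experts = nn.Linear(16, 16)                # stays in the default 'dp' group
    gate = nn.Linear(16, 4)
    for p in gate.parameters():
        setattr(p, 'dp_comm', 'world')         # reduced across the whole world, as in the gate change below
    model = nn.ModuleDict({'experts': experts, 'gate': gate})
    for (dp_comm, dtype), params in group_params_by_comm(model).items():
        print(dp_comm, dtype, len(params))     # e.g. dp torch.float32 2 / world torch.float32 2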
@@ -92,6 +92,8 @@ class FMoETransformerMLP(nn.Module):
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
+        for p in self.gate.parameters():
+            setattr(p, 'dp_comm', 'world')
         self.layer_norm = nn.LayerNorm(d_model)
         self.bias = torch.nn.parameter.Parameter(
@@ -18,6 +18,8 @@ def create_moe_mlp(args, model_parallel_rank, group):
         model_parallel_rank=model_parallel_rank,
         mp_group=group,
     )
+    for p in fmoe.gate.parameters():
+        setattr(p, 'shared', True)
     return fmoe
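The Megatron adapter above additionally flags the gate's parameters as shared. How Megatron consumes that flag is not part of this diff; purely as an illustration, a replicated-versus-partitioned split driven by such a flag could look like the sketch below (split_shared_params is a hypothetical helper, not Megatron or FastMoE code).

import torch.nn as nn

def split_shared_params(module):
    # Illustrative only: separate parameters flagged as 'shared' (replicated)
    # from the remaining (partitioned) ones.
    shared, partitioned = [], []
    for p in module.parameters():
        (shared if getattr(p, 'shared', False) else partitioned).append(p)
    return shared, partitioned

if __name__ == '__main__':
    gate = nn.Linear(8, 2)
    for p in gate.parameters():
        setattr(p, 'shared', True)             # mirrors the loop added in the hunk above
    experts = nn.Linear(8, 8)
    model = nn.ModuleDict({'gate': gate, 'experts': experts})
    shared, partitioned = split_shared_params(model)
    print(len(shared), len(partitioned))       # 2 shared, 2 partitioned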
@@ -29,7 +29,7 @@ if __name__ == '__main__':
             }
         )
     ],
-    version='0.0.1',
+    version='0.0.2',
     cmdclass={
         'build_ext': BuildExtension
     })