Commit d83234b0 authored by Rick Ho

use parallel label in gate

parent 67c667f2
@@ -29,20 +29,20 @@ class DistributedGroupedDataParallel(nn.Module):
         for p in self.module.parameters():
             if not p.requires_grad or p.grad is None:
                 continue
-            if hasattr(p, 'parallel_method'):
-                pm = p.parallel_method
+            if hasattr(p, 'dp_comm'):
+                dp_comm = p.dp_comm
             else:
-                pm = 'dp'
-            group_key = (pm, p.dtype)
+                dp_comm = 'dp'
+            group_key = (dp_comm, p.dtype)
             if group_key not in groups:
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for pm, dtype in groups:
-            if pm not in self.comms:
+        for dp_comm, dtype in groups:
+            if dp_comm not in self.comms:
                 continue
-            group = groups[pm, dtype]
-            comm = self.comms[pm]
+            group = groups[dp_comm, dtype]
+            comm = self.comms[dp_comm]
             grads = [p.grad.data for p in group]
             coalesced = _flatten_dense_tensors(grads)
             if fp32_allreduce and dtype != torch.float32:
......
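For readers skimming the hunk above, here is a minimal standalone sketch of the grouped all-reduce pattern it renames: gradients are bucketed by (dp_comm label, dtype), flattened, and reduced over the process group that the label selects. The helper name allreduce_gradients, the comms mapping, and the explicit averaging step are assumptions for illustration, not the repository's exact code.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients(module, comms, fp32_allreduce=False):
    # Hypothetical helper: comms maps labels such as 'dp' and 'world'
    # to torch.distributed process groups.
    groups = {}
    for p in module.parameters():
        if not p.requires_grad or p.grad is None:
            continue
        # The dp_comm label attached to a parameter picks its reduction group;
        # unlabeled parameters fall back to the plain data-parallel group.
        dp_comm = getattr(p, 'dp_comm', 'dp')
        groups.setdefault((dp_comm, p.dtype), []).append(p)
    for (dp_comm, dtype), params in groups.items():
        if dp_comm not in comms:
            continue
        comm = comms[dp_comm]
        grads = [p.grad.data for p in params]
        coalesced = _flatten_dense_tensors(grads)
        if fp32_allreduce and dtype != torch.float32:
            coalesced = coalesced.float()
        dist.all_reduce(coalesced, group=comm)
        coalesced /= dist.get_world_size(group=comm)
        for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(synced)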
@@ -92,6 +92,8 @@ class FMoETransformerMLP(nn.Module):
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
+        for p in self.gate.parameters():
+            setattr(p, 'dp_comm', 'world')
         self.layer_norm = nn.LayerNorm(d_model)
         self.bias = torch.nn.parameter.Parameter(
......
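A hedged usage sketch of what the two added lines buy: because the gate's parameters now carry dp_comm='world', a wrapper such as DistributedGroupedDataParallel above all-reduces their gradients over the full world group, while untagged expert parameters keep the default 'dp' behavior. The constructor argument names below are inferred from the fields visible in this hunk and may differ from the actual signature.

# Illustrative only; assumes FMoETransformerMLP and DistributedGroupedDataParallel
# from the files above are in scope, and that the argument names match the hunk.
mlp = FMoETransformerMLP(num_expert=4, d_model=512, d_hidden=2048,
                         world_size=1, top_k=2)
for p in mlp.gate.parameters():
    assert getattr(p, 'dp_comm', None) == 'world'   # gate grads sync world-wide
model = DistributedGroupedDataParallel(mlp)          # expert grads default to 'dp'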
@@ -18,6 +18,8 @@ def create_moe_mlp(args, model_parallel_rank, group):
         model_parallel_rank=model_parallel_rank,
         mp_group=group,
     )
+    for p in fmoe.gate.parameters():
+        setattr(p, 'shared', True)
     return fmoe
......
@@ -29,7 +29,7 @@ if __name__ == '__main__':
             }
         )
     ],
-    version='0.0.1',
+    version='0.0.2',
     cmdclass={
         'build_ext': BuildExtension
     })