Commit 27c89b5a authored by Rick Ho

customized dp-comm and hidden-hidden-size

parent 87dad9d5
@@ -157,17 +157,14 @@ class FMoE(nn.Module):
             base_idx += batch_size
         return torch.cat(outputs, dim=0)
 
-    def mark_parallel_comm(self):
+    def mark_parallel_comm(self, expert_dp_comm='none'):
         r'''
         Automatically mark the data parallel comms of the parameters within the
         module. This can be typically called at the end of the __init__ function
         in child classes.
         '''
         if self.experts is not None:
-            if self.world_size > self.mp_size:
-                comm = 'none'
-            else:
-                comm = 'dp'
+            comm = expert_dp_comm
             if isinstance(self.experts, list):
                 for e in self.experts:
                     mark_module_parallel_comm(e, comm)
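With the old world_size/mp_size heuristic removed, a subclass of FMoE now states explicitly how its expert parameters should be synchronized. A minimal sketch of how a child class might pass the new argument at the end of its __init__, assuming a plain torch.nn.Linear stand-in for the expert and that 'dp' tags the parameters for ordinary data-parallel all-reduce:

    import torch

    class MyMoE(FMoE):
        def __init__(self, num_expert, d_model, world_size=1):
            super().__init__(num_expert=num_expert, d_model=d_model,
                             world_size=world_size)
            # Stand-in expert module, for illustration only.
            self.experts = torch.nn.Linear(d_model, d_model)
            # Tag expert parameters with the desired comm group instead of
            # relying on the old world_size/mp_size check.
            self.mark_parallel_comm(expert_dp_comm='dp')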
@@ -24,7 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         super().__init__(args.num_experts,
-                d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
+                d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
                 world_size=world_size, mp_group=group)
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
@@ -34,7 +34,8 @@ class MegatronMLP(FMoETransformerMLP):
         return super().forward(inp), self.bias
 
 
-def fmoefy(model, num_experts=None, distributed_experts=True):
+def fmoefy(model, num_experts=None, distributed_experts=True,
+           hidden_hidden_size=None):
     r'''
     Replace MLP layers in a transformer-based model in Megatron by MoE.
     * `model` should be a standard Megatron model that has
@@ -57,6 +58,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
         'num_experts' in args
     ), 'num_experts should be specified in arguments or fmoefy function'
 
+    if hidden_hidden_size is not None:
+        args.hidden_hidden_size = hidden_hidden_size
+    elif not hasattr(args, 'hidden_hidden_size'):
+        args.hidden_hidden_size = args.hidden_size * 4
+
     # Set distributed_experts to None to use default setting in args
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
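On the Megatron side, fmoefy now accepts hidden_hidden_size to override the expert MLP width instead of always using hidden_size * 4. A hedged usage sketch, assuming the fmoe.megatron import path and that `model` is an already-built Megatron transformer:

    from fmoe.megatron import fmoefy  # import path assumed

    # Override the expert feed-forward width; without this argument the
    # patched code keeps an existing args.hidden_hidden_size and otherwise
    # falls back to args.hidden_size * 4.
    fmoefy(model, num_experts=8, hidden_hidden_size=2048)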
@@ -47,7 +47,8 @@ class FMoETransformerMLP(FMoE):
             activation=torch.nn.functional.gelu,
             gate=NaiveGate,
             top_k=2,
-            pre_lnorm=False
+            pre_lnorm=False,
+            expert_dp_comm='none'
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
@@ -55,7 +56,7 @@ class FMoETransformerMLP(FMoE):
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
-        self.mark_parallel_comm()
+        self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
         r'''
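FMoETransformerMLP exposes the same knob and simply forwards it to mark_parallel_comm. A minimal sketch of constructing the layer with a non-default setting, assuming the fmoe.transformer import path, a working fastmoe build (the CUDA extension), and a single process so the default world_size applies:

    import torch
    from fmoe.transformer import FMoETransformerMLP  # import path assumed

    # 'dp' replicates expert gradients across the data-parallel group;
    # the default 'none' keeps each rank's experts local.
    mlp = FMoETransformerMLP(num_expert=4, d_model=512, d_hidden=2048,
                             expert_dp_comm='dp')
    out = mlp(torch.randn(16, 512))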