"src/diffusers/models/autoencoders/vae.py" did not exist on "878af0e113a0f730557f627c24de3bc056916a2d"
Commit 27c89b5a authored by Rick Ho

customized dp-comm and hidden-hidden-size

parent 87dad9d5
@@ -157,17 +157,14 @@ class FMoE(nn.Module):
             base_idx += batch_size
         return torch.cat(outputs, dim=0)
 
-    def mark_parallel_comm(self):
+    def mark_parallel_comm(self, expert_dp_comm='none'):
         r'''
         Automatically mark the data parallel comms of the parameters within the
         module. This can be typically called at the end of the __init__ function
         in child classes.
         '''
         if self.experts is not None:
-            if self.world_size > self.mp_size:
-                comm = 'none'
-            else:
-                comm = 'dp'
+            comm = expert_dp_comm
             if isinstance(self.experts, list):
                 for e in self.experts:
                     mark_module_parallel_comm(e, comm)
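Note on the hunk above: the data-parallel communication tag for expert parameters is no longer derived from the world_size > mp_size heuristic; the caller now supplies it through expert_dp_comm. As a rough, hedged sketch (not the library's exact internals; the dp_comm attribute name and the helper below are illustrative), marking a module amounts to tagging each of its parameters so the data-parallel wrapper can later pick the matching gradient all-reduce group:

import torch.nn as nn

def mark_module_parallel_comm_sketch(module: nn.Module, comm: str) -> None:
    # Tag every parameter with the requested communication group name,
    # e.g. 'dp' (replicate via data parallelism) or 'none' (skip the all-reduce).
    for p in module.parameters():
        setattr(p, 'dp_comm', comm)

A subclass can now call self.mark_parallel_comm(expert_dp_comm='dp') at the end of its __init__ to keep the replicated-expert behaviour, or leave the default 'none' when experts are sharded across ranks.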
@@ -24,7 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         super().__init__(args.num_experts,
-                d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
+                d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
                 world_size=world_size, mp_group=group)
         self.bias = torch.nn.parameter.Parameter(
             torch.zeros(args.hidden_size, dtype=torch.float32)
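For context on why d_hidden becomes configurable: it sets the width of each expert's inner projection, which was previously hard-wired to four times hidden_size. A minimal sketch of an expert feed-forward block, with hypothetical class and attribute names, just to show where the two sizes enter:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ExpertFFNSketch(nn.Module):
    # Illustrative only: one expert MLP parameterized by d_model and d_hidden.
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)  # project up to the expert's hidden width
        self.fc2 = nn.Linear(d_hidden, d_model)  # project back down to the model width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))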
@@ -34,7 +34,8 @@ class MegatronMLP(FMoETransformerMLP):
         return super().forward(inp), self.bias
 
 
-def fmoefy(model, num_experts=None, distributed_experts=True):
+def fmoefy(model, num_experts=None, distributed_experts=True,
+        hidden_hidden_size=None):
     r'''
     Replace MLP layers in a transformer-based model in Megatron by MoE.
     * `model` should be a standard Megatron model that has
@@ -57,6 +58,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
         'num_experts' in args
     ), 'num_experts should be specified in arguments or fmoefy function'
 
+    if hidden_hidden_size is not None:
+        args.hidden_hidden_size = hidden_hidden_size
+    elif not hasattr(args, 'hidden_hidden_size'):
+        args.hidden_hidden_size = args.hidden_size * 4
+
     # Set distributed_experts to None to use default setting in args
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
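Resolution order for the new argument: an explicit hidden_hidden_size passed to fmoefy wins, an existing args.hidden_hidden_size is respected next, and otherwise the old hidden_size * 4 default applies. A hedged usage sketch (the import path and the way `model` and its global args are built are assumptions, not part of this diff):

from fmoe.megatron import fmoefy  # import path assumed

# `model` is a standard Megatron transformer whose global args include
# hidden_size; fmoefy replaces its MLP layers with MoE layers in place.
fmoefy(model, num_experts=8, hidden_hidden_size=2048)

# Omitting hidden_hidden_size (with no args.hidden_hidden_size set beforehand)
# preserves the previous behaviour: d_hidden falls back to hidden_size * 4.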
@@ -47,7 +47,8 @@ class FMoETransformerMLP(FMoE):
             activation=torch.nn.functional.gelu,
             gate=NaiveGate,
             top_k=2,
-            pre_lnorm=False
+            pre_lnorm=False,
+            expert_dp_comm='none'
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
@@ -55,7 +56,7 @@ class FMoETransformerMLP(FMoE):
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
-        self.mark_parallel_comm()
+        self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
         r'''
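With the new keyword, expert weights default to the 'none' communication group, i.e. they are left out of the data-parallel gradient all-reduce, matching the old world_size > mp_size case; passing 'dp' recovers the old replicated-expert case. A hedged construction example (import path and defaults for the unlisted keywords are assumptions):

from fmoe.transformer import FMoETransformerMLP  # import path assumed

moe_ffn = FMoETransformerMLP(
    num_expert=4,         # experts hosted on this rank
    d_model=1024,         # transformer hidden size
    d_hidden=4096,        # per-expert inner width (cf. hidden_hidden_size above)
    world_size=1,         # single-process example
    expert_dp_comm='dp',  # replicate expert gradients across data-parallel ranks
)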