Commit 57bdfe88 authored by Rick Ho

fix megatron adapter for swipe

parent 8a56481b

@@ -7,3 +7,5 @@ from .noisy_gate import NoisyGate
 from .gshard_gate import GShardGate
 from .switch_gate import SwitchGate
+from .swipe_gate import SwipeGate
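
Re-exporting SwipeGate from the package root lets callers pick it up the same way as the other gates. Below is a hedged usage sketch; FMoETransformerMLP, its import path, and its keyword arguments are assumptions about fastmoe's public API, not part of this commit.

from fmoe.gates import SwipeGate
from fmoe.transformer import FMoETransformerMLP  # assumed import path

# Build an MoE feed-forward block routed by SwipeGate; the gate is passed
# as a class and instantiated internally (assumed behavior).
moe_mlp = FMoETransformerMLP(
    num_expert=4,      # experts per worker
    d_model=1024,      # transformer hidden size
    d_hidden=4096,     # expert FFN inner size
    gate=SwipeGate,
)
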
@@ -23,3 +23,7 @@ class BaseGate(nn.Module):
         if clear:
             self.loss = None
         return loss
+
+    @property
+    def has_loss(self):
+        return self.loss is not None
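
The new has_loss property gives callers a safe way to skip gates that never record a balance loss before calling get_loss(). A condensed, runnable sketch of that pattern; the real BaseGate constructor takes additional arguments that are omitted here.

import torch
from torch import nn

class BaseGate(nn.Module):
    # Abbreviated: only the loss-tracking part of the class is shown.
    def __init__(self):
        super().__init__()
        self.loss = None

    def set_loss(self, loss):
        self.loss = loss

    def get_loss(self, clear=True):
        loss = self.loss
        if clear:
            self.loss = None
        return loss

    @property
    def has_loss(self):
        return self.loss is not None

gate = BaseGate()
assert not gate.has_loss              # a gate that never records a loss (e.g. swipe)
gate.set_loss(torch.tensor(0.1))
assert gate.has_loss                  # a gate that does (e.g. noisy/gshard/switch)
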
@@ -51,9 +51,12 @@ def add_balance_log(model, writer, iteration):
     while hasattr(model, 'module'):
         model = model.module
-    balance_dict_tensor = torch.vstack(
-        [l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
-    ).detach()
+    losses = [l.mlp.gate.get_loss(clear=True)
+              for l in model.language_model.transformer.layers
+              if l.mlp.gate.has_loss]
+    if len(losses) == 0:
+        return
+    balance_dict_tensor = torch.vstack(losses).detach()
     world_group = get_torch_default_comm()
     world_size = torch.distributed.get_world_size(group=world_group)
     torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
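
With the filter in place, the logging path only stacks losses that actually exist. A hedged sketch of the reduction it performs: all_reduce sums the stacked per-layer losses across ranks, so dividing by world_size yields the mean that gets logged. The helper name and the single-process setup are illustrative only.

import torch
import torch.distributed as dist

def mean_balance_loss(losses, group=None):
    # Stack per-layer scalar losses, sum them across ranks, then average.
    stacked = torch.vstack(losses).detach()
    dist.all_reduce(stacked, group=group)   # default reduce op is SUM
    return stacked / dist.get_world_size(group=group)

if __name__ == "__main__":
    # One-process group so the sketch runs standalone.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    per_layer = [torch.tensor([0.01]), torch.tensor([0.03])]
    print(mean_balance_loss(per_layer))
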
@@ -95,6 +95,9 @@ class MegatronMLP(FMoETransformerMLP):
         elif args.balance_strategy == "switch":
             from fmoe.gates import SwitchGate
             gate = SwitchGate
+        elif args.balance_strategy == "swipe":
+            from fmoe.gates import SwipeGate
+            gate = SwipeGate
         elif gate is None:
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)
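
The elif chain now dispatches four strategies. A hypothetical registry-style equivalent is sketched below; the gate classes are the real fmoe.gates exports, while the helper itself is illustrative and not fastmoe API.

from fmoe.gates import GShardGate, NoisyGate, SwitchGate, SwipeGate

# Maps each args.balance_strategy value to a gate class; "swipe" is the
# entry this commit adds. The noisy/gshard mappings are assumed here.
BALANCE_GATES = {
    "noisy": NoisyGate,
    "gshard": GShardGate,
    "switch": SwitchGate,
    "swipe": SwipeGate,
}

def select_gate(balance_strategy):
    try:
        return BALANCE_GATES[balance_strategy]
    except KeyError:
        raise ValueError(
            "Undefined balance strategy {}".format(balance_strategy))
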
@@ -20,15 +20,19 @@ def patch_forward_step(forward_step_func):
         args = get_args()
         output = forward_step_func(data_iterator, model, input_tensor)
-        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
+        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output
-        loss_name = args.balance_strategy + "_loss"
         while hasattr(model, 'module'):
             model = model.module
         loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
-                for l in model.language_model.transformer.layers]
+                for l in model.language_model.transformer.layers
+                if l.mlp.gate.has_loss]
+        if len(loss_list) == 0:
+            return output
+        loss_name = args.balance_strategy + "_loss"
         (loss, state_dict), bal_loss = (
             output,
             torch.cat(loss_list).mean() * args.balance_loss_weight
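
After the reordering above, loss_name is only computed once at least one gate has recorded a loss, and bal_loss is the weighted mean of the surviving entries. A runnable sketch of that final combination; the helper is illustrative, with names mirroring the diff.

import torch

def combine_losses(lm_loss, loss_list, balance_loss_weight):
    # Mean of the per-layer balance losses, scaled by the configured
    # weight, added on top of the language-model loss.
    bal_loss = torch.cat(loss_list).mean() * balance_loss_weight
    return lm_loss + bal_loss, bal_loss

lm_loss = torch.tensor(2.3)
loss_list = [torch.tensor([0.01]), torch.tensor([0.02])]  # one entry per MoE layer
total, bal = combine_losses(lm_loss, loss_list, balance_loss_weight=0.01)
print(total.item())  # ~2.30015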