Unverified commit bba5f289, authored by Rick Ho and committed by GitHub
Browse files

Merge pull request #64 from laekov/fp16_bal_loss

fix fp16 training with balance loss
parents 537679a8 ebefe2b1
......@@ -49,7 +49,7 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
def add_balance_log(model, writer, iteration):
from megatron import is_last_rank
- if hasattr(model, 'module'):
+ while hasattr(model, 'module'):
model = model.module
balance_dict_tensor = torch.vstack(
......@@ -92,7 +92,7 @@ def patch_forward_step(forward_step_func):
return output
loss_name = args.balance_strategy + "_loss"
- if hasattr(model, 'module'):
+ while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment