"gallery/v2_transforms/plot_transforms_v2.py" did not exist on "a27522cb06d23bb8949c69dadb8e9d0f24b48c00"
Unverified Commit bba5f289 authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #64 from laekov/fp16_bal_loss

fix fp16 training with balance loss
parents 537679a8 ebefe2b1
......@@ -49,7 +49,7 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
def add_balance_log(model, writer, iteration):
from megatron import is_last_rank
if hasattr(model, 'module'):
while hasattr(model, 'module'):
model = model.module
balance_dict_tensor = torch.vstack(
......@@ -92,7 +92,7 @@ def patch_forward_step(forward_step_func):
return output
loss_name = args.balance_strategy + "_loss"
if hasattr(model, 'module'):
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment