"...text-generation-inference.git" did not exist on "f91e9d282d73e09cdb876924412f2ed66212d736"
Unverified Commit bba5f289 authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #64 from laekov/fp16_bal_loss

fix fp16 training with balance loss
parents 537679a8 ebefe2b1
...@@ -49,7 +49,7 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global): ...@@ -49,7 +49,7 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
def add_balance_log(model, writer, iteration): def add_balance_log(model, writer, iteration):
from megatron import is_last_rank from megatron import is_last_rank
if hasattr(model, 'module'): while hasattr(model, 'module'):
model = model.module model = model.module
balance_dict_tensor = torch.vstack( balance_dict_tensor = torch.vstack(
...@@ -92,7 +92,7 @@ def patch_forward_step(forward_step_func): ...@@ -92,7 +92,7 @@ def patch_forward_step(forward_step_func):
return output return output
loss_name = args.balance_strategy + "_loss" loss_name = args.balance_strategy + "_loss"
if hasattr(model, 'module'): while hasattr(model, 'module'):
model = model.module model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1) loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment