"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "a210ec74d29ee718bca9b3c192e0a93cf86cbf21"
Unverified commit 28bfe689, authored by Jiezhong Qiu, committed by GitHub
Browse files

Merge pull request #50 from laekov/fix-balance-loss

Fix grad of balance loss
parents 295a615a a12ad553
...@@ -94,15 +94,11 @@ def patch_forward_step(forward_step_func): ...@@ -94,15 +94,11 @@ def patch_forward_step(forward_step_func):
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers] loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = ( (loss, state_dict), bal_loss = (
output, output,
( torch.cat(loss_list).mean() * args.balance_loss_weight
torch.tensor(
loss_list, device=loss_list[0].device
).mean()
* args.balance_loss_weight
).float(),
) )
# average across world group # average across world group
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment