Commit 913d7127 authored by Rick Ho
Browse files

use cat instead of creating new tensor

parent 295a615a
...@@ -97,12 +97,7 @@ def patch_forward_step(forward_step_func): ...@@ -97,12 +97,7 @@ def patch_forward_step(forward_step_func):
loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers] loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = ( (loss, state_dict), bal_loss = (
output, output,
( torch.cat(loss_list).mean() * args.balance_loss_weight
torch.tensor(
loss_list, device=loss_list[0].device
).mean()
* args.balance_loss_weight
).float(),
) )
# average across world group # average across world group
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment