Commit a12ad553 authored by Rick Ho
Browse files

fix concat shape

parent 913d7127
......@@ -94,7 +94,8 @@ def patch_forward_step(forward_step_func):
if hasattr(model, 'module'):
model = model.module
-loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers]
+loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
+for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment