"vscode:/vscode.git/clone" did not exist on "f84537e0e05f091736c6be2dff8622b01a8c973d"
Unverified commit 680c610f, authored by Arthur and committed by GitHub
Browse files

Hot-fix-mixstral-loss (#27948)

* fix loss computation

* compute on GPU if possible
parent 4b759da8
......@@ -95,7 +95,8 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
if isinstance(gate_logits, tuple):
# cat along the layers?
- gate_logits = torch.cat(gate_logits, dim=0)
+ compute_device = gate_logits[0].device
+ gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
routing_weights = routing_weights.softmax(dim=-1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment