Unverified Commit 39ef3fb2 authored by Siming Dai, committed by GitHub

[Mixtral] Fixes attention masking in the loss (#29363)



Fix mixtral load balancing loss
Co-authored-by: dingkunbo <dingkunbo@baidu.com>
parent 38953a75
@@ -123,8 +123,8 @@ def load_balancing_loss_func(
     # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
     expert_attention_mask = (
         attention_mask[None, :, :, None, None]
-        .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
-        .reshape(-1, 2, num_experts)
+        .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+        .reshape(-1, top_k, num_experts)
         .to(compute_device)
     )
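
For context, a minimal standalone sketch of the broadcast this hunk fixes. The shapes below are illustrative assumptions (only top_k = 2 and num_experts = 8 match Mixtral's defaults); with top_k equal to 2 the old hard-coded 2 happened to line up, but any other top_k would leave the expanded padding mask with a shape that no longer matches expert_mask.

# Sketch only, not the full Transformers implementation.
import torch

num_hidden_layers, batch_size, sequence_length = 4, 2, 6  # example values
num_experts, top_k = 8, 2  # with top_k == 2 the hard-coded 2 coincidentally worked

# attention_mask: (batch_size, sequence_length); 1 = real token, 0 = padding
attention_mask = torch.ones(batch_size, sequence_length)
attention_mask[0, -2:] = 0  # pretend the first sequence ends with two padding tokens

# Broadcast the padding mask to (num_hidden_layers, batch_size, sequence_length,
# top_k, num_experts), then flatten the leading dims so it aligns with expert_mask
# and padded positions contribute nothing to the load-balancing loss.
expert_attention_mask = (
    attention_mask[None, :, :, None, None]
    .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
    .reshape(-1, top_k, num_experts)
)
print(expert_attention_mask.shape)  # torch.Size([48, 2, 8])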