Unverified Commit 465968b2 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix dtype error in CI (#8197)

parent 750838ad
......@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk(
gating_output,
gating_output.to(dtype=torch.float32),
correction_bias,
topk_weights,
topk_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment