Unverified Commit 465968b2 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix dtype error in CI (#8197)

parent 750838ad
...@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu( ...@@ -524,7 +524,7 @@ def biased_grouped_topk_gpu(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk( aiter_biased_grouped_topk(
gating_output, gating_output.to(dtype=torch.float32),
correction_bias, correction_bias,
topk_weights, topk_weights,
topk_ids, topk_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment