Unverified Commit 620e8924 authored by Yongye Zhu, committed by GitHub
Browse files

[Bugfix] [Tests] Enforce `out` tensor device in `kernel/moe/test_cutedsl_moe.py` (#39644)


Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
parent f00c5539
......@@ -142,7 +142,9 @@ def prepare_inputs(
# Initialize the hidden_states_3d with ones instead of empty to avoid nan
# issue.
hidden_states_3d = torch.ones(
- (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
+ (num_experts, max(masked_m), hidden_states.shape[1]),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
)
for i in range(num_experts):
hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
......@@ -426,7 +428,7 @@ def test_flashinfer_cutedsl_moe_masked(
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
- out = torch.empty_like(hidden_states_3d)
+ out = torch.empty_like(hidden_states_3d, device=hidden_states.device)
# Note: the 1st dim shouldn't be bs
wk = torch.empty(
num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment