Commit 26b999c7 authored by Michael Goin's avatar Michael Goin Committed by simon-mo
Browse files

[CI Failure] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe (#24750)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent da3fa78d
......@@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
......
......@@ -771,11 +771,11 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
w13_ref = dequant_mxfp4_batches(
w13_q.view(torch.uint8),
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, 2 * intermediate_size, hidden_size)
num_experts, 2 * intermediate_size, hidden_size).to(device)
w2_ref = dequant_mxfp4_batches(
w2_q.view(torch.uint8),
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, hidden_size, intermediate_size)
num_experts, hidden_size, intermediate_size).to(device)
# Quantize activations for SM100 path and dequantize for reference
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment