"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "33801c1597edca5dd04c5de117db89b5bc27f43a"
Unverified commit ee0b3c5b authored by Yuhao Yao, committed by GitHub

[1/N][Bug] Fix w4afp8 MoE NaN issue (sgl-kernel, fixed) (#10108)

parent 6049ca20
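Context for the change (a hedged reading, not stated in the commit itself): the diff below switches the kernel's default output type from FP16 to BF16 and makes the test reference compute in float32. One way such a switch avoids NaNs is FP16's narrow dynamic range (max finite value 65504): results that overflow to inf can turn into NaN in later arithmetic, while BF16 shares float32's 8-bit exponent range. A minimal PyTorch sketch of that failure mode, using made-up values:

import torch

# Values above FP16's max finite value (~65504) overflow to inf on the cast...
x = torch.tensor([70000.0, -70000.0])
x_fp16 = x.to(torch.float16)    # tensor([inf, -inf], dtype=torch.float16)
x_bf16 = x.to(torch.bfloat16)   # stays finite: BF16 keeps float32's exponent range

# ...and inf + (-inf) in a subsequent reduction produces NaN.
print(x_fp16.sum())             # tensor(nan, dtype=torch.float16)
print(x_bf16.sum())             # tensor(0., dtype=torch.bfloat16)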
@@ -41,8 +41,8 @@
 using MmaType = cutlass::float_e4m3_t;     // FP8 e4m3 type
 using QuantType = cutlass::int4b_t;        // 4-bit integer type
 using ElementAccumulator = float;          // Accumulator type
 using ElementScale = cutlass::bfloat16_t;  // Scale type
-using ElementC = cutlass::half_t;          // Default output type (FP16)
-using ElementD = ElementC;                 // Default output type (FP16)
+using ElementC = cutlass::bfloat16_t;      // Output type
+using ElementD = ElementC;                 // Output type
 using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
 // Architecture-specific configurations
...
@@ -96,7 +96,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
     # Create output tensor
-    c = torch.empty((m, n), dtype=torch.float16, device=device)
+    c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c,
         a_q,
...
@@ -211,7 +211,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     b_strides = a_strides
     s_strides = c_strides
-    c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
+    c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c_perm,
         a_q_perm,
...
@@ -262,10 +262,9 @@ def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_r
             continue
         a = a_q[token_idx]
-        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
-        ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
-        c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
-        c = c.to(dtype)
+        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(torch.float32)
+        ref_w = w[i].to(torch.float32) * ref_w_scale_repeat
+        c = torch.matmul(a.to(torch.float32), ref_w.t()) * a_scale
         c_ref[token_idx] = c.to(dtype)
     return c_ref
...
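The corrected reference in the last hunk follows one pattern: dequantize and accumulate entirely in float32, then cast to the output dtype exactly once at the end, instead of round-tripping intermediates through the low-precision dtype. A standalone sketch under assumed shapes (per-group weight scales every 128 columns; names mirror the test but are not taken verbatim from it):

import torch

def ref_gemm_fp32(a_q, a_scale, w, w_scale, out_dtype=torch.bfloat16):
    # Dequantize the quantized weight with its per-128-column group scales in float32.
    w_scale_rep = w_scale.repeat_interleave(128, dim=1).to(torch.float32)
    w_deq = w.to(torch.float32) * w_scale_rep              # (n, k) dequantized weight
    # Matmul and activation scaling in float32; cast to the output dtype once.
    c = torch.matmul(a_q.to(torch.float32), w_deq.t()) * a_scale
    return c.to(out_dtype)

# Toy usage: m=4 tokens, k=256, n=128, group size 128 -> 2 scale groups per row.
a_q = torch.randn(4, 256)
w = torch.randint(-8, 8, (128, 256)).to(torch.float32)     # stand-in for int4 weights
w_scale = torch.rand(128, 2)                               # (n, k // 128) group scales
out = ref_gemm_fp32(a_q, 1.0, w, w_scale)
print(out.dtype, out.shape)                                # torch.bfloat16 torch.Size([4, 128])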