Unverified Commit 2286e85e authored by Rain Jiang, committed by GitHub

Pass a_scale from the fp8 quant result instead of hard-coding it to 1.0f (#10241)


Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
Co-authored-by: Jinwu Guo <641876696@qq.com>
parent 91b3555d
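The commit replaces a hard-coded activation scale of 1.0f with the scale returned by per-tensor FP8 quantization. As a minimal plain-PyTorch sketch (an illustration only; the diff below routes this through the sgl_kernel `sgl_per_tensor_quant_fp8` kernel via a new `_per_tensor_quant_fp8` test helper), dynamic per-tensor FP8 quantization computes both the quantized tensor and the scale that must then be handed to the GEMM:

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

    def per_tensor_quant_fp8_ref(x: torch.Tensor):
        # Dynamic per-tensor scale: map the tensor's absolute maximum onto the FP8 range.
        amax = x.abs().max().clamp(min=1e-12).to(torch.float32)
        scale = amax / FP8_E4M3_MAX
        # Quantize: divide by the scale, clamp to the FP8 range, then cast.
        x_q = (x.to(torch.float32) / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        return x_q.to(torch.float8_e4m3fn), scale.reshape(1)

The key point of the change is that this scale is no longer assumed to be 1.0; it is passed to the kernel (as `a_scales` on the C++ side) so the epilogue can rescale the accumulator correctly.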
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
     cutlass_w4a8_moe_mm(
         c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )
-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)
     intermediate_q = torch.empty(
...
@@ -209,7 +209,7 @@ void cutlass_w4a8_group_gemm_caller(
   Args arguments;
   decltype(arguments.epilogue.thread) fusion_args;
-  fusion_args.alpha = 1.0f;
+  fusion_args.alpha = 0;
   fusion_args.beta = 0;
   fusion_args.alpha_ptr = a_scales.data_ptr<float>();
   ;
...
 import pytest
 import torch
-from sgl_kernel import cutlass_w4a8_moe_mm
+from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
 from utils import is_hopper
@@ -67,7 +67,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     if debug:
         a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
         ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
-        a_scale = torch.ones(1, dtype=torch.float, device=device)
         ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
     else:
         a = torch.randn(m, k, dtype=dtype, device=device)
@@ -75,7 +74,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
             -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
         )
         affine_coeff = 0.005
-        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
         ref_w_scale = (
             torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
             * affine_coeff
@@ -93,7 +91,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     s_strides = c_strides
     # Quantize input
-    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
+    a_q, a_scale = _per_tensor_quant_fp8(a)
     # Create output tensor
     c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
@@ -117,7 +115,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     # Reference implementation
     experts_selection_result = torch.full((m,), 0)
     c_ref = ref_grouped_gemm(
-        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
+        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
     )
     # Compare results
@@ -138,17 +136,29 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
         raise
-# @pytest.mark.skipif(
-#     not is_hopper(),
-#     reason="cutlass_w4a8_moe_mm is only supported on sm90",
-# )
+def _per_tensor_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+    sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False)
+    return x_q, x_s
 @pytest.mark.skipif(
-    True,
-    reason="TODO(rainj-me): fix cu129 binary issue on hopper cu126",
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
 )
-@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
-@pytest.mark.parametrize("k", [256, 512, 1024])
-@pytest.mark.parametrize("n", [1024, 2048, 7168])
+@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
+@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
+@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
 @pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
 def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     torch.manual_seed(0)
@@ -163,7 +173,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     if debug:
         a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
         ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
-        a_scale = torch.ones(1, dtype=torch.float, device=device)
         ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
     else:
         a = torch.randn(batch_size, k, dtype=dtype, device=device)
@@ -171,7 +180,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
             -8, 8, (num_experts, n, k), dtype=torch.int8, device=device
         )
         affine_coeff = 0.005
-        a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
         ref_w_scale = (
             torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
             * affine_coeff
@@ -202,12 +210,8 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
     # Permute input and quantize
-    a_perm = a[permutation]
-    a_q_perm = (
-        torch.clamp((a_perm / a_scale), -448.0, 448.0)
-        .to(torch.float8_e4m3fn)
-        .to(device)
-    )
+    a_q, a_scale = _per_tensor_quant_fp8(a)
+    a_q_perm = a_q[permutation]
     # Create stride tensors
     a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
@@ -238,7 +242,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     c = c.to(dtype)
     c_ref = ref_grouped_gemm(
-        c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
+        c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
     )
     # Compare results
@@ -256,10 +260,11 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
         raise
-def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
+def ref_grouped_gemm(
+    c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
+):
     dtype = torch.bfloat16
     c_ref = torch.zeros_like(c)
-    a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
     for i in range(num_experts):
         token_idx = torch.where(experts_selection_result == i)[0]
         if len(token_idx) == 0:
...
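With the change above, the reference path receives the already-quantized activations plus their scale rather than re-quantizing against a made-up scale. A simplified sketch of how a_q and a_scale are consumed on the reference side (an assumption for illustration: it uses a single per-tensor weight scale instead of the k // 128 group scales the real test applies):

    import torch

    def ref_gemm_dequant(a_q: torch.Tensor, a_scale: torch.Tensor,
                         w: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
        a_deq = a_q.to(torch.float32) * a_scale        # undo activation quantization
        w_deq = w.to(torch.float32) * w_scale          # undo weight quantization
        return (a_deq @ w_deq.t()).to(torch.bfloat16)  # [m, k] x [k, n] -> [m, n]

This mirrors what the kernel's epilogue now does: with alpha_ptr pointing at a_scales, the accumulator is rescaled by the real quantization scale instead of the previously hard-coded 1.0f.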