"vscode:/vscode.git/clone" did not exist on "494d2b7a072cdf838b8e4378be25cf4115c64d8f"
Unverified Commit 2286e85e authored by Rain Jiang's avatar Rain Jiang Committed by GitHub
Browse files

pass a_scale from fp8 quant result instead of hard code to 1.0f (#10241)


Co-authored-by: default avatarYichen Wang <yichen.wang@bytedance.com>
Co-authored-by: default avatarJinwu Guo <641876696@qq.com>
parent 91b3555d
......@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
k,
)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
cutlass_w4a8_moe_mm(
c1,
......@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
topk,
)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
silu_and_mul(c1, intermediate)
intermediate_q = torch.empty(
......
......@@ -209,7 +209,7 @@ void cutlass_w4a8_group_gemm_caller(
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = 1.0f;
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
;
......
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm
from sgl_kernel import cutlass_w4a8_moe_mm, sgl_per_tensor_quant_fp8
from utils import is_hopper
......@@ -67,7 +67,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
if debug:
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(m, k, dtype=dtype, device=device)
......@@ -75,7 +74,6 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
......@@ -93,7 +91,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
s_strides = c_strides
# Quantize input
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
a_q, a_scale = _per_tensor_quant_fp8(a)
# Create output tensor
c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
......@@ -117,7 +115,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
# Reference implementation
experts_selection_result = torch.full((m,), 0)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
......@@ -138,17 +136,29 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
raise
# @pytest.mark.skipif(
# not is_hopper(),
# reason="cutlass_w4a8_moe_mm is only supported on sm90",
# )
def _per_tensor_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = torch.float8_e4m3fn,
):
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_s = torch.empty(
1,
device=x.device,
dtype=torch.float32,
)
sgl_per_tensor_quant_fp8(x, x_q, x_s, is_static=False)
return x_q, x_s
@pytest.mark.skipif(
True,
reason="TODO(rainj-me): fix cu129 binary issue on hopper cu126",
not is_hopper(),
reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [256, 512, 1024])
@pytest.mark.parametrize("n", [1024, 2048, 7168])
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
torch.manual_seed(0)
......@@ -163,7 +173,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
if debug:
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(batch_size, k, dtype=dtype, device=device)
......@@ -171,7 +180,6 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
......@@ -202,12 +210,8 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
# Permute input and quantize
a_perm = a[permutation]
a_q_perm = (
torch.clamp((a_perm / a_scale), -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(device)
)
a_q, a_scale = _per_tensor_quant_fp8(a)
a_q_perm = a_q[permutation]
# Create stride tensors
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
......@@ -238,7 +242,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
c = c.to(dtype)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
c, a_q, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
......@@ -256,10 +260,11 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
raise
def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
def ref_grouped_gemm(
c, a_q, a_scale, w, w_scale, num_experts, experts_selection_result
):
dtype = torch.bfloat16
c_ref = torch.zeros_like(c)
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
for i in range(num_experts):
token_idx = torch.where(experts_selection_result == i)[0]
if len(token_idx) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment