import itertools
import unittest

import torch
import torch.nn as nn

# TODO: use interface in cpu.py
from sgl_kernel.common_ops import (
    convert_weight_packed,
    fp8_scaled_mm_cpu,
    int8_scaled_mm_cpu,
    int8_scaled_mm_with_quant,
    per_token_quant_int8_cpu,
    weight_packed_linear,
)
from utils import (
    convert_weight,
    native_w8a8_per_token_matmul,
    per_token_quant_int8,
    precision,
)

from sglang.test.test_utils import CustomTestCase


class Mod(nn.Module):
    """Single-linear-layer module used to source weights for the fp8 test."""

    def __init__(self, input_channel, output_channel, has_bias):
        super().__init__()
        self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)

    def forward(self, x):
        return self.linear(x)


class TestGemm(CustomTestCase):
    # Shape grids per precision; N and K are kept at multiples of 32 to match
    # the packing block size of the CPU kernels.
    M = [1, 101]
    N = [32 * 13]
    K = [32 * 16]
    has_bias = [False, True]

    M_int8 = [2, 128]
    N_int8 = [32 * 12]
    K_int8 = [32 * 17]

    M_fp8 = [1, 11]
    N_fp8 = [128, 224]
    K_fp8 = [512, 576]

    def _bf16_gemm(self, M, N, K, has_bias):
        # Reference: fp32 matmul, rounded back to bf16 at the end.
        mat1 = torch.randn(M, K, dtype=torch.bfloat16)
        mat2 = torch.randn(N, K, dtype=torch.bfloat16)
        ref = torch.matmul(mat1.float(), mat2.float().t())

        bias = None
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            ref.add_(bias.bfloat16())
        ref = ref.bfloat16()

        # Plain weights and prepacked weights should both match the reference.
        out = weight_packed_linear(mat1, mat2, bias, False)

        packed_mat2 = convert_weight_packed(mat2)
        out2 = weight_packed_linear(mat1, packed_mat2, bias, True)

        atol = rtol = precision[ref.dtype]
        self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol))

    def test_bf16_gemm(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._bf16_gemm(*params)

    def _int8_gemm(self, M, N, K, has_bias):
        dtype = torch.bfloat16
        A = torch.randn((M, K), dtype=dtype) / 10
        Aq, As = per_token_quant_int8(A)

        # Build an int8 weight with small per-output-channel scales.
        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128
        B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
        Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        Bs = torch.rand(N) * factor_for_scale

        bias = torch.randn(N) if has_bias else None

        ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
        atol = rtol = precision[ref_out.dtype]

        # Quantize with the CPU kernel, then run the scaled matmul.
        Aq2, As2 = per_token_quant_int8_cpu(A)
        out = int8_scaled_mm_cpu(Aq2, Bq, As2, Bs, bias, torch.bfloat16, False)
        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))

        # The fused quantize + matmul version should give the same result.
        fused_out = int8_scaled_mm_with_quant(A, Bq, Bs, bias, torch.bfloat16, False)
        self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))

    def test_int8_gemm(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._int8_gemm(*params)

    def _fp8_gemm(self, M, N, K, has_bias):
        prepack = True
        chunk = False
        scale_block_size_N = 64
        scale_block_size_K = 128
        assert scale_block_size_N <= N
        assert scale_block_size_K <= K

        A_dtype = torch.bfloat16
        model = Mod(K, N, has_bias).eval()

        if chunk:
            # Exercise a non-contiguous activation by narrowing a wider tensor.
            data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
        else:
            data = torch.randn(M, K, dtype=A_dtype)

        weight = model.linear.weight  # (N, K)
        bias = model.linear.bias if has_bias else None

        # Quantize the weight to fp8 with [N, K] scale blocks; dq_weight is the
        # dequantized copy used to build the reference output.
        fp8_weight, scales, dq_weight = convert_weight(
            weight, [scale_block_size_N, scale_block_size_K], A_dtype
        )

        ref = torch.matmul(data.to(A_dtype), dq_weight.T)
        if has_bias:
            ref = ref + bias.to(A_dtype)

        if prepack:
            fp8_weight = convert_weight_packed(fp8_weight)

        opt = fp8_scaled_mm_cpu(
            data,
            fp8_weight,
            scales,
            [scale_block_size_N, scale_block_size_K],
            bias,
            data.dtype,
            prepack,
        )

        atol = rtol = precision[ref.dtype]
        self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol))

    def test_fp8_gemm(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._fp8_gemm(*params)


if __name__ == "__main__":
    unittest.main()
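

# For reference: a minimal sketch of the per-token symmetric int8 quantization
# these tests rely on (one fp32 absmax scale per row). This assumes that is the
# scheme implemented by utils.per_token_quant_int8 and per_token_quant_int8_cpu;
# it is documentation for readers, not the kernel implementation itself.
def _per_token_quant_int8_sketch(x: torch.Tensor):
    # One scale per token (row): row-wise absmax / 127, computed in fp32;
    # clamp to avoid division by zero on all-zero rows.
    scales = x.float().abs().amax(dim=-1, keepdim=True) / 127.0
    scales = scales.clamp(min=1e-10)
    q = (x.float() / scales).round().clamp(-128, 127).to(torch.int8)
    return q, scales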