Unverified Commit 4068e012 authored by Qingquan Song, committed by GitHub

Fix per token fp8 quant precision (#4362)

parent 817d4370
@@ -22,10 +22,9 @@ def vllm_per_token_quant_fp8(
 def sglang_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
+    scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
     output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
     sgl_per_token_quant_fp8(input, output, scale)
     return output, scale
@@ -37,9 +36,6 @@ def calculate_diff(batch_size: int, seq_len: int):
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
-    scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
-    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
     if torch.allclose(
         vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
...
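For reference, the computation both wrappers benchmark can be written in a few lines of plain PyTorch. This is a hedged sketch, not the CUDA kernel: it assumes the usual per-token scheme (scale = row_max / FP8_E4M3_MAX, output = input / scale) and a PyTorch build that exposes torch.float8_e4m3fn; the helper name per_token_quant_fp8_ref is made up for illustration.

# Minimal PyTorch sketch of per-token FP8 quantization (not the CUDA kernel).
# Assumes torch >= 2.1 with torch.float8_e4m3fn available.
from typing import Tuple

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def per_token_quant_fp8_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row); shape (num_tokens, 1) so it broadcasts on dequant.
    row_max = x.abs().float().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = row_max / FP8_E4M3_MAX
    # Divide by the scale directly (mirroring the kernel fix) and clamp to the FP8 range.
    q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale


if __name__ == "__main__":
    x = torch.randn(4, 16, dtype=torch.float16)
    q, scale = per_token_quant_fp8_ref(x)
    # Dequantize via broadcasting; a (num_tokens, 1) scale needs no reshape.
    x_hat = q.float() * scale
    print(scale.shape, (x.float() - x_hat).abs().max().item())

Allocating the scale as (num_tokens, 1) up front is what lets dequantization broadcast it across the hidden dimension without a later reshape.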
@@ -49,8 +49,6 @@ __global__ void per_token_quant_fp8_kernel(
   }
   __syncthreads();
-  const float scale_val = 1.0f / block_max;
   // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
@@ -59,7 +57,7 @@ __global__ void per_token_quant_fp8_kernel(
     FP8_TYPE output_arr[vec_size];
     #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX);
     #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
     #else
...
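The precision fix itself is the removal of the precomputed reciprocal: x * (1.0f / block_max) involves two float32 roundings (one when forming the reciprocal, one in the multiply), while x / block_max involves only one, so the two paths can differ in the last bit. The snippet below is a pure illustration of that effect in PyTorch, not the kernel:

# Float32 demo: multiplying by a precomputed reciprocal vs. dividing directly.
# Illustration only; values and shapes are arbitrary.
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000, dtype=torch.float32)
s = torch.rand(1, dtype=torch.float32) + 0.5  # some positive per-token scale

via_reciprocal = x * (1.0 / s)  # old path: extra rounding when forming 1/s
via_division = x / s            # new path: single rounding

mismatch = (via_reciprocal != via_division).float().mean().item()
max_rel = ((via_reciprocal - via_division).abs()
           / via_division.abs().clamp(min=1e-30)).max().item()
print(f"fraction of elements that differ: {mismatch:.4f}, max relative diff: {max_rel:.2e}")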
@@ -21,18 +21,16 @@ def vllm_per_token_quant_fp8(
 def sglang_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
+    scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
     output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
     sgl_per_token_quant_fp8(input, output, scale)
-    scale = scale.reshape(-1, 1)
     return output, scale

 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
-    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
+    list(itertools.product([32, 64, 128, 256, 512], [128, 256, 512, 2048, 4096])),
 )
 def test_per_token_quant_compare_implementations(
     num_tokens: int,
@@ -44,7 +42,7 @@ def test_per_token_quant_compare_implementations(
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5)
     torch.testing.assert_close(
         vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
     )
...
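Since the kernel now writes one scale per row into a (num_tokens, 1) tensor, the test drops the reshape(-1, 1), widens the parametrization to smaller shapes, and tightens the scale tolerance to atol=1e-5. The shape matters because torch.testing.assert_close checks shapes as well as values; the snippet below is an illustration with made-up sizes, not part of the test:

# Illustration of the shape contract: assert_close refuses to compare a (N,)
# scale against a (N, 1) one, so allocating (num_tokens, 1) up front replaces
# the old reshape(-1, 1). Sizes are arbitrary.
import torch

num_tokens = 32
scale_1d = torch.zeros(num_tokens)     # old allocation: shape (num_tokens,)
scale_2d = torch.zeros(num_tokens, 1)  # new allocation: shape (num_tokens, 1)

try:
    torch.testing.assert_close(scale_1d, scale_2d)
except AssertionError as e:
    print("shape mismatch as expected:", e)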