"vscode:/vscode.git/clone" did not exist on "e0aa838971266abef1b58985bef6f092f4ab72df"
Unverified Commit db7343c9 authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

fix per token cuda kernel hidden dim cannot divide by 16 (#8543)

parent 533cb5b2
...@@ -12,6 +12,39 @@ from sglang.srt.utils import is_hip ...@@ -12,6 +12,39 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# Get correct FP8 E4M3 maximum value
if _is_hip:
FP8_E4M3_MAX = 224.0 # ROCM uses 224.0
else:
# For CUDA, get the actual max value from the type
FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
def torch_per_token_quant_fp8(
input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pure PyTorch reference implementation for per-token FP8 quantization."""
device = input.device
dtype = input.dtype
# Find max absolute value per token (row) - exactly like CUDA kernel
max_vals = torch.abs(input).max(dim=1)[0] # [num_tokens]
# Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
scales = max_vals / FP8_E4M3_MAX # [num_tokens]
# No special zero handling - directly compute 1.0 / scale like CUDA kernel
scale_inv = 1.0 / scales # [num_tokens]
# Quantize: input * scale_inv, then clamp to FP8 range
quantized_float = input * scale_inv.unsqueeze(1) # Broadcast scale_inv
quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)
# Convert to FP8 - use more explicit conversion
quantized_fp8 = quantized_float.to(fp8_type_)
return quantized_fp8, scales
def vllm_per_token_quant_fp8( def vllm_per_token_quant_fp8(
input: torch.Tensor, input: torch.Tensor,
...@@ -29,53 +62,100 @@ def sglang_per_token_quant_fp8( ...@@ -29,53 +62,100 @@ def sglang_per_token_quant_fp8(
return output, scale return output, scale
def calculate_diff(batch_size: int, seq_len: int): def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
"""Calculate difference between VLLM and SGLang implementations.""" """Compare Torch reference, VLLM, and SGLang implementations."""
device = torch.device("cuda") device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device) x = torch.rand(
(batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
)
# Get all three implementations
torch_out, torch_scale = torch_per_token_quant_fp8(x)
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x) vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x) sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item() print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
if torch.allclose( # Compare scales
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
print("✅ All implementations match") vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
else:
print("❌ Implementations differ") print(f"Scale differences:")
print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}")
print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}")
# Compare outputs
torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
torch_sglang_out_diff = (
torch.abs(torch_out.float() - sglang_out.float()).mean().item()
)
vllm_sglang_out_diff = (
torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
)
print(f"Output differences:")
print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}")
print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}")
print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}")
# Check tolerances
rtol, atol = 1e-3, 1e-5
torch_vllm_match = torch.allclose(
torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
torch_sglang_match = torch.allclose(
torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)
if hidden_dim == 1368:
rtol = 1e-2
# we found vllm sglang has diff when hidden dim is not dividable by 16
# and we believe SGLang is closer to Torch implementation
vllm_sglang_match = torch.allclose(
vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)
print(f"Matches (rtol={rtol}, atol={atol}):")
print(f" Torch vs VLLM: {'✅' if torch_vllm_match else '❌'}")
print(f" Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")
batch_size_range = [16, 32, 64, 128] batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096] seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range)) configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size", "seq_len"], x_names=["batch_size", "seq_len", "hidden_dim"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["vllm", "sglang"], line_vals=["torch", "vllm", "sglang"],
line_names=["VLLM", "SGL Kernel"], line_names=["Torch Reference", "VLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")], styles=[("red", "-"), ("blue", "-"), ("green", "-")],
ylabel="us", ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance", plot_name="per-token-dynamic-quant-fp8-performance",
args={}, args={},
) )
) )
def benchmark_quantization(batch_size, seq_len, provider): def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
dtype = torch.float16 dtype = torch.float16
device = torch.device("cuda") device = torch.device("cuda")
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "vllm": if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone())
elif provider == "vllm":
fn = lambda: vllm_per_token_quant_fp8(x.clone()) fn = lambda: vllm_per_token_quant_fp8(x.clone())
elif provider == "sglang": elif provider == "sglang":
fn = lambda: sglang_per_token_quant_fp8(x.clone()) fn = lambda: sglang_per_token_quant_fp8(x.clone())
...@@ -86,5 +166,12 @@ def benchmark_quantization(batch_size, seq_len, provider): ...@@ -86,5 +166,12 @@ def benchmark_quantization(batch_size, seq_len, provider):
if __name__ == "__main__": if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096) # Test various hidden dimensions for correctness
test_dims = [1368, 2048, 4096]
for dim in test_dims:
calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
print("\n" + "=" * 60)
print("Starting performance benchmark...")
benchmark_quantization.run(print_data=True) benchmark_quantization.run(print_data=True)
...@@ -75,14 +75,21 @@ __global__ void per_token_quant_fp8_kernel( ...@@ -75,14 +75,21 @@ __global__ void per_token_quant_fp8_kernel(
c10::Float8_e4m3fnuz::from_bits()); c10::Float8_e4m3fnuz::from_bits());
#endif #endif
} }
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; if constexpr (kVecSize == 16) {
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
} else {
// Use element-wise copy for vector size 8 to ensure correctness
for (int k = 0; k < kVecSize; ++k) {
token_output[i * kVecSize + k] = output_arr[k];
}
}
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce) // 2. Baseline kernel (1 token / CTA, CUB block reduce)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE> template <typename T, typename DST_DTYPE, int kVecSize = 16>
__global__ void per_token_quant_fp8_small_batch_kernel( __global__ void per_token_quant_fp8_small_batch_kernel(
const T* __restrict__ input, const T* __restrict__ input,
DST_DTYPE* __restrict__ output_q, DST_DTYPE* __restrict__ output_q,
...@@ -100,19 +107,17 @@ __global__ void per_token_quant_fp8_small_batch_kernel( ...@@ -100,19 +107,17 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
float max_value = 0.0f; float max_value = 0.0f;
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits // Use template parameter for vector size
// Load is already vectorized, so 16 elements work for T. using vec_t = flashinfer::vec_t<T, kVecSize>;
const uint32_t VEC_SIZE = 16; const int32_t num_vec_elems = hidden_dim / kVecSize;
using vec_t = flashinfer::vec_t<T, VEC_SIZE>;
const int32_t num_vec_elems = hidden_dim / VEC_SIZE;
// Find max using vectorized loads // Find max using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) { for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec; vec_t input_vec;
input_vec.cast_load(token_input + i * VEC_SIZE); input_vec.cast_load(token_input + i * kVecSize);
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) { for (uint32_t j = 0; j < kVecSize; ++j) {
float val = static_cast<float>(input_vec[j]); float val = static_cast<float>(input_vec[j]);
max_value = fmaxf(max_value, fabsf(val)); max_value = fmaxf(max_value, fabsf(val));
} }
...@@ -132,11 +137,11 @@ __global__ void per_token_quant_fp8_small_batch_kernel( ...@@ -132,11 +137,11 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
// Quantize using vectorized loads // Quantize using vectorized loads
for (int32_t i = tid; i < num_vec_elems; i += block_dim) { for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
vec_t input_vec; vec_t input_vec;
input_vec.cast_load(token_input + i * VEC_SIZE); input_vec.cast_load(token_input + i * kVecSize);
DST_DTYPE output_arr[VEC_SIZE]; DST_DTYPE output_arr[kVecSize];
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) { for (uint32_t j = 0; j < kVecSize; ++j) {
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM #ifndef USE_ROCM
output_arr[j] = static_cast<DST_DTYPE>(val); output_arr[j] = static_cast<DST_DTYPE>(val);
...@@ -147,7 +152,14 @@ __global__ void per_token_quant_fp8_small_batch_kernel( ...@@ -147,7 +152,14 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
#endif #endif
} }
*(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr; if constexpr (kVecSize == 16) {
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
} else {
// Use element-wise copy for vector size 8 to ensure correctness
for (int k = 0; k < kVecSize; ++k) {
token_output[i * kVecSize + k] = output_arr[k];
}
}
} }
} }
...@@ -158,13 +170,14 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: ...@@ -158,13 +170,14 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
const auto input_sizes = input.sizes(); const auto input_sizes = input.sizes();
const int64_t num_tokens = input_sizes[0]; const int64_t num_tokens = input_sizes[0];
const int64_t hidden_dim = input_sizes[1]; const int64_t hidden_dim = input_sizes[1];
TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim); TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Hard-code sm_count // Hard-code sm_count
int sm_count = 132; int sm_count = 132;
constexpr int TOKENS_PER_CTA = 8; constexpr int TOKENS_PER_CTA = 8;
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA); const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
const bool use_vec16 = (hidden_dim % 16 == 0);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
if (use_warp_kernel) { if (use_warp_kernel) {
...@@ -172,23 +185,43 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: ...@@ -172,23 +185,43 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256 constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
dim3 block(THREADS); dim3 block(THREADS);
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()), if (use_vec16) {
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
static_cast<float*>(output_s.data_ptr()), static_cast<const scalar_t*>(input.data_ptr()),
hidden_dim, static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
num_tokens); static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
} else {
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()),
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
}
} else { } else {
// -------- baseline ----------------------------------------------------- // -------- baseline -----------------------------------------------------
constexpr int THREADS = 256; constexpr int THREADS = 256;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(THREADS); dim3 block(THREADS);
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()), if (use_vec16) {
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
static_cast<float*>(output_s.data_ptr()), static_cast<const scalar_t*>(input.data_ptr()),
hidden_dim, static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
num_tokens); static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
} else {
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()),
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
}
} }
return true; return true;
}); });
......
...@@ -36,7 +36,7 @@ def sglang_per_token_quant_fp8( ...@@ -36,7 +36,7 @@ def sglang_per_token_quant_fp8(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_tokens,hidden_dim", "num_tokens,hidden_dim",
list(itertools.product([128, 256, 512], [512, 2048, 4096])), list(itertools.product([128, 256, 512], [512, 1368, 2048, 4096])),
) )
def test_per_token_quant_compare_implementations( def test_per_token_quant_compare_implementations(
num_tokens: int, num_tokens: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment