"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "24c875c6dc362911c0e0426942f93638cd25613c"
Unverified commit 63ee26d1 authored by Stefan He, committed by GitHub

Add sgl_per_token_quant_fp8 (#4089)

parent ad55f171
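
For reference, the new op performs per-token (row-wise) dynamic FP8 quantization: each row's scale is its absolute maximum divided by the FP8 E4M3 maximum, and the row is divided by that scale and cast to FP8. Below is a minimal PyTorch sketch of that semantics, assuming a CUDA build where torch.float8_e4m3fn is available; the function name and the tensor x are illustrative and not part of this commit.

import torch

def per_token_quant_fp8_reference(x: torch.Tensor):
    # Illustrative sketch only: mirrors what the CUDA kernel in this commit computes.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # One scale per token: row-wise abs-max mapped onto the FP8 range.
    # Note: like the kernel, no special handling for an all-zero row.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale

The CUDA kernel added here computes the same result with one thread block per token, a warp plus shared-memory max reduction, and 16-byte vectorized loads; the benchmark and test below compare it against vLLM's scaled_fp8_quant.
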
import itertools
from typing import Optional, Tuple

import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)


def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    sgl_per_token_quant_fp8(input, output, scale)
    return output, scale


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between VLLM and SGLang implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    print(f"Scale difference: {scale_diff}")
    print(f"Output difference: {output_diff}")

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
@@ -106,6 +106,7 @@ sources = [
    "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
    "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
    "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
    "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
@@ -29,6 +29,7 @@ from sgl_kernel.ops.gemm import (
    int8_scaled_mm,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_quant_fp8,
)
from sgl_kernel.ops.moe import moe_align_block_size
from sgl_kernel.ops.sampling import (
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

#define WARP_SIZE 32

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>

#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from PyTorch (240.0) causes accuracy issues
// when running dynamic quantization, so use 224.0f for ROCm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

// Butterfly (XOR shuffle) reduction: afterwards every lane in the warp holds the warp-wide max.
__device__ __forceinline__ float warpReduceMax(float max_value) {
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
  return max_value;
}

// One thread block per token: compute the row-wise abs-max, derive the per-token
// scale, then quantize the row to FP8.
template <typename T>
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
                                           float* __restrict__ output_s, const int64_t hidden_dim,
                                           const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_dim;
  FP8_TYPE* token_output = output_q + token_idx * hidden_dim;

  float max_value = 0.0f;

  // Each thread scans a strided slice of the row for its local abs-max.
  for (int i = tid; i < hidden_dim; i += block_dim) {
    float val = static_cast<float>(token_input[i]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  // Reduce within each warp, then across warps via shared memory.
  max_value = warpReduceMax(max_value);

  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  if (warpId == 0) {
    max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
    max_value = warpReduceMax(max_value);
  }

  __shared__ float block_max;
  if (tid == 0) {
    // Per-token dynamic scale: abs-max of the row mapped onto the FP8 range.
    block_max = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = block_max;
  }
  __syncthreads();

  const float scale_val = 1.0f / block_max;

  // Vectorized path: 16-byte loads cover most of the row.
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = hidden_dim / vec_size;

  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      token_output[i * vec_size + j] = output_arr[j];
    }
  }

  // Scalar tail for elements that do not fill a full vector.
  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    token_output[idx] = static_cast<FP8_TYPE>(val);
#else
    token_output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}

void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const auto input_sizes = input.sizes();
  const int64_t num_tokens = input_sizes[0];
  const int64_t hidden_dim = input_sizes[1];

  // One 128-thread block per token.
  const int block_size = 128;
  const int num_blocks = num_tokens;

  dim3 grid(num_blocks);
  dim3 block(block_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()), hidden_dim, num_tokens);
    return true;
  });
}
@@ -160,3 +160,6 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv);
// sgl_per_token_quant_fp8
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
@@ -118,3 +118,11 @@ def cublas_grouped_gemm(
        cublas_handle,
        get_cuda_stream(),
    )


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s)
@@ -171,6 +171,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
}
REGISTER_EXTENSION(_kernels) REGISTER_EXTENSION(_kernels)
import itertools
from typing import Optional, Tuple

import pytest
import torch
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)


def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    sgl_per_token_quant_fp8(input, output, scale)
    # Reshape to (num_tokens, 1) so the scale matches vLLM's per-token layout.
    scale = scale.reshape(-1, 1)
    return output, scale


@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_token_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(
        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
    )


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])