Unverified Commit ad55f171 authored by Xiaoyu Zhang, committed by GitHub

[quant kernel] sgl-kernel support per_tensor_quant fp8 (#3786)

parent 361971b8
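For context, a minimal usage sketch of the new per-tensor FP8 quantization op, following the wrappers added in the benchmark and test below (output dtype shown for the CUDA path; the wrappers switch to torch.float8_e4m3fnuz on ROCm):

import torch
from sgl_kernel import sgl_per_tensor_quant_fp8

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic quantization: pass a zero-initialized scale placeholder; the kernel
# fills it with absmax / FP8_E4M3_MAX and then quantizes with that scale.
q_dyn = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scale = torch.zeros(1, dtype=torch.float32, device=x.device)
sgl_per_tensor_quant_fp8(x, q_dyn, scale, False)

# Static quantization: reuse a precomputed per-tensor scale.
q_static = torch.empty_like(x, dtype=torch.float8_e4m3fn)
sgl_per_tensor_quant_fp8(x, q_static, scale, True)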
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, scale)


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        # Dynamic quantization: the kernel computes the scale and writes it here.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale
def calculate_diff(batch_size: int, seq_len: int):
    """Calculate the difference between the vLLM and SGLang implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    scale_diff = torch.abs(vllm_scale - sglang_scale).item()
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print(
            f"❌ Implementations differ (scale diff: {scale_diff}, mean output diff: {output_diff})"
        )
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
configs = list(itertools.product(batch_size_range, seq_len_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-tensor-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        fn = lambda: vllm_scaled_fp8_quant(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_scaled_fp8_quant(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark.run(print_data=True)
@@ -106,6 +106,7 @@ sources = [
    "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
    "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
    "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
    "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
...
@@ -27,6 +27,7 @@ from sgl_kernel.ops.gemm import (
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
)
from sgl_kernel.ops.moe import moe_align_block_size
...
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
#define WARP_SIZE 32
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from PyTorch (240.0) causes accuracy issues when
// running dynamic quantization, so use 224.0f for ROCm instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
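// Atomic float max built on integer atomics: non-negative floats order like signed ints (atomicMax on the
// int bit pattern); negative floats order in reverse as unsigned ints, handled with atomicMin.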
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
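// Butterfly (xor-shuffle) reduction over the 32 lanes of a warp; every lane ends up holding the warp-wide maximum.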
__device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
return max_value;
}
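// Pass 1 (dynamic quantization only): grid-stride abs-max over the tensor using 16-byte vectorized loads,
// reduced warp- and block-wide, then folded into output_s as block_max / FP8_E4M3_MAX via an atomic max.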
template <typename T>
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
const int64_t num_elements) {
float max_value = 0.0f;
unsigned int tid = threadIdx.x;
unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x;
constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
const int32_t num_vec_elems = num_elements / vec_size;
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
vec_t input_vec;
input_vec.cast_load(input + i * vec_size);
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
float val = static_cast<float>(input_vec[j]);
max_value = fmaxf(max_value, fabsf(val));
}
}
const int32_t remaining_start = num_vec_elems * vec_size;
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
float val = static_cast<float>(input[idx]);
max_value = fmaxf(max_value, fabsf(val));
}
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
if (tid == 0) {
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
}
}
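// Pass 2: grid-stride quantization; each element is multiplied by 1/scale, clamped to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX], and cast to the FP8 output type (vectorized main loop plus a scalar tail).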
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output,
const float* __restrict__ scale, const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x;
const float scale_val = 1.0f / (*scale);
constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
const int32_t num_vec_elems = num_elements / vec_size;
for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
vec_t input_vec;
input_vec.cast_load(input + i * vec_size);
FP8_TYPE output_arr[vec_size];
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
}
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
output[i * vec_size + j] = output_arr[j];
}
}
const int32_t remaining_start = num_vec_elems * vec_size;
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
output[idx] = static_cast<FP8_TYPE>(val);
#else
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
}
}
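// Host entry point: launches up to 1024 blocks of 256 threads. When is_static is false, output_s
// (which must be zero-initialized, since it is reduced with atomicMaxFloat) is first filled by the
// abs-max kernel, and the tensor is then quantized with that scale.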
void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
CHECK_INPUT(output_s);
const int block_size = 256;
const int num_elements = input.numel();
const int num_blocks = min((num_elements + block_size - 1) / block_size, 1024);
dim3 grid(num_blocks);
dim3 block(block_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
if (is_static == false) {
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
}
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), num_elements);
return true;
});
}
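For readability, a PyTorch sketch (not part of the commit) of what the dynamic path above computes: a first pass reduces the tensor to its absolute maximum and derives the per-tensor scale, and a second pass scales, clamps, and casts to FP8. torch.finfo is used here in place of the hard-coded FP8_E4M3_MAX constant.

import torch

def per_tensor_quant_fp8_reference(x: torch.Tensor):
    # Reference for the CUDA path (float8_e4m3fn, max 448.0); the ROCm kernel clamps to 224.0 instead.
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    # Pass 1: per-tensor scale from the absolute maximum (per_tensor_absmax_kernel).
    scale = x.abs().amax().float() / fp8_max
    # Pass 2: scale, clamp to the representable range, cast (per_tensor_quant_fp8_kernel).
    q = (x.float() / scale).clamp(min=-fp8_max, max=fp8_max).to(fp8_dtype)
    return q, scale.reshape(1)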
@@ -92,6 +92,7 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
                                      const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
                                   double eps, double fp8_min, double fp8_max);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
                         const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
                         int64_t cublas_handle, int64_t cuda_stream);
...
@@ -91,6 +91,15 @@ def sgl_per_token_group_quant_fp8(
    )


def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
...
@@ -90,6 +90,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
      " float eps, float fp8_min, float fp8_max) -> ()");
  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);

  m.def(
      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
...
import itertools
from typing import Optional, Tuple
import pytest
import torch
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, scale)


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        # Dynamic quantization: the kernel computes the scale and writes it here.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale
@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_tensor_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    # Dynamic quantization: both implementations compute their own scale.
    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(
        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
    )

    # Static quantization: both implementations use the same precomputed scale.
    scale = torch.rand(1, dtype=torch.float32, device=device)
    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)

    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(
        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
    )


if __name__ == "__main__":
    pytest.main([__file__])