Unverified commit 65c24c28 authored by Chunan Zeng, committed by GitHub

[Quant Kernel] refactored per token group quant fp8 to support int8, up to 2x faster (#4396)

parent 3980ff1b
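For context, a minimal usage sketch of the int8 entry point this commit exposes (not part of the commit itself): it calls sgl_per_token_group_quant_int8 through the Python wrapper added in sgl-kernel. The shapes, group_size, and eps below are illustrative, and it assumes a CUDA device with this version of sgl-kernel installed; the caller preallocates the quantized output and the per-group scales.

# Hypothetical usage sketch; shapes, group_size, and eps are illustrative.
import torch
from sgl_kernel import sgl_per_token_group_quant_int8

group_size = 128
x = torch.randn(4, 256, 1024, device="cuda", dtype=torch.float16)  # last dim divisible by group_size
x_q = torch.empty_like(x, dtype=torch.int8)                        # quantized values
x_s = torch.empty(                                                 # one float32 scale per group
    x.shape[:-1] + (x.shape[-1] // group_size,), device="cuda", dtype=torch.float32
)
iinfo = torch.iinfo(torch.int8)
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, 1e-10, iinfo.min, iinfo.max)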
@@ -4,7 +4,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
 from sglang.srt.utils import is_hip
@@ -13,7 +13,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -24,16 +24,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
@@ -47,30 +46,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
@@ -79,12 +75,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min

-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
@@ -97,15 +97,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -114,50 +114,55 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s


-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )

-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

     return x_q, x_s


-def calculate_diff(batch_size, seq_len, group_size):
-    dtype = torch.float16
+def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
     device = torch.device("cuda")
     hidden_dim = group_size * 2

-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)

-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
+        x.clone(), group_size, dst_dtype
+    )
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
+        x.clone(), group_size, dst_dtype
+    )

     if torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
-        print("✅ All implementations match")
+        print(f"✅ {dst_dtype} implementations match")
     else:
         print("❌ Implementations differ")
@@ -165,36 +170,40 @@ def calculate_diff(batch_size, seq_len, group_size):
 batch_size_range = [1, 2, 4, 8, 16, 32, 64]
 seq_len_range = [64, 128, 256, 512, 1024, 2048]
 group_size_range = [128]  # For DeepSeek V3/R1
+dst_dtype_range = [torch.int8, fp8_type_]

-configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range))
+configs = list(
+    itertools.product(
+        batch_size_range, seq_len_range, group_size_range, dst_dtype_range
+    )
+)


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len", "group_size"],
+        x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
         line_names=["Triton", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
-        plot_name="per-token-group-quant-fp8-performance",
+        plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(batch_size, seq_len, group_size, provider):
-    dtype = torch.bfloat16
+def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
     device = torch.device("cuda")
     hidden_dim = 7168

-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)

     quantiles = [0.5, 0.2, 0.8]

     if provider == "triton":
-        fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     elif provider == "sglang":
-        fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)

     ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

@@ -203,6 +212,7 @@ def benchmark(batch_size, seq_len, group_size, provider):
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=128, group_size=64)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
     benchmark.run(print_data=True)
@@ -6,8 +6,6 @@
 #include "utils.h"

-using FP8_TYPE = c10::Float8_e4m3fn;
-
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
@@ -18,27 +16,28 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }

-template <typename T, int GROUPS_PER_BLOCK = 16>
-__global__ void per_token_group_quant_fp8_kernel(
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
     float* __restrict__ output_s,
     const int group_size,
     const int num_groups,
+    const int groups_per_block,
     const float eps,
-    const float fp8_min,
-    const float fp8_max) {
+    const float min_8bit,
+    const float max_8bit) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;

-  const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK;
+  const int block_group_id = blockIdx.x * groups_per_block;
   const int block_group_offset = (block_group_id + local_group_id) * group_size;

   float local_absmax = eps;

   const T* group_input = input + block_group_offset;
-  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
+  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
   float* scale_output = output_s + (block_group_id + local_group_id);

   constexpr uint32_t vec_size = 16 / sizeof(T);
@@ -60,7 +59,7 @@ __global__ void per_token_group_quant_fp8_kernel(
   local_absmax = GroupReduceMax(local_absmax, lane_id);

-  const float y_s = local_absmax / fp8_max;
+  const float y_s = local_absmax / max_8bit;

   if (lane_id == 0) {
     *scale_output = y_s;
@@ -73,20 +72,20 @@ __global__ void per_token_group_quant_fp8_kernel(
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
       float val = static_cast<float>(input_vec[j]);
-      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
-      group_output[i * vec_size + j] = FP8_TYPE(q_val);
+      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
+      group_output[i * vec_size + j] = DST_DTYPE(q_val);
     }
   }
 }

-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max) {
+    double min_8bit,
+    double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
@@ -111,36 +110,58 @@ void sgl_per_token_group_quant_fp8(
     groups_per_block = 2;
   }

-#define LAUNCH_KERNEL(T, GPB) \
-  do { \
-    constexpr int GROUPS_PER_BLOCK = GPB; \
-    dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK); \
-    dim3 block(GROUPS_PER_BLOCK* THREADS_PER_GROUP); \
-    per_token_group_quant_fp8_kernel<T, GROUPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
-        static_cast<T*>(input.data_ptr()), \
-        output_q.data_ptr(), \
-        static_cast<float*>(output_s.data_ptr()), \
-        group_size, \
-        num_groups, \
-        (float)eps, \
-        (float)fp8_min, \
-        (float)fp8_max); \
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+#define LAUNCH_KERNEL(T, DST_DTYPE) \
+  do { \
+    dim3 grid(num_blocks); \
+    dim3 block(num_threads); \
+    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
+        static_cast<T*>(input.data_ptr()), \
+        output_q.data_ptr(), \
+        static_cast<float*>(output_s.data_ptr()), \
+        group_size, \
+        num_groups, \
+        groups_per_block, \
+        (float)eps, \
+        (float)min_8bit, \
+        (float)max_8bit); \
   } while (0)

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (groups_per_block == 16) {
-      LAUNCH_KERNEL(scalar_t, 16);
-    } else if (groups_per_block == 8) {
-      LAUNCH_KERNEL(scalar_t, 8);
-    } else if (groups_per_block == 4) {
-      LAUNCH_KERNEL(scalar_t, 4);
-    } else if (groups_per_block == 2) {
-      LAUNCH_KERNEL(scalar_t, 2);
-    } else {
-      LAUNCH_KERNEL(scalar_t, 1);
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+      return true;
     }
-    return true;
+    return false;
   });

 #undef LAUNCH_KERNEL
 }
+
+void sgl_per_token_group_quant_int8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
+}
+
+void sgl_per_token_group_quant_fp8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double fp8_min,
+    double fp8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+}
@@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
......
@@ -141,6 +141,14 @@ void sgl_per_token_group_quant_fp8(
     double eps,
     double fp8_min,
     double fp8_max);

+void sgl_per_token_group_quant_int8(
+    at::Tensor input,
+    at::Tensor output_q,
+    at::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max);

 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void cublas_grouped_gemm(
......
@@ -31,6 +31,7 @@ from sgl_kernel.gemm import (
     int8_scaled_mm,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
+    sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
 from sgl_kernel.moe import moe_align_block_size, topk_softmax
......
@@ -96,6 +96,20 @@ def sgl_per_token_group_quant_fp8(
     )


+def sgl_per_token_group_quant_int8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    int8_min: float,
+    int8_max: float,
+) -> None:
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+        input, output_q, output_s, group_size, eps, int8_min, int8_max
+    )
+
+
 def sgl_per_tensor_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
......
@@ -153,7 +153,7 @@ sources = [
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
-    "csrc/gemm/per_token_group_quant_fp8.cu",
+    "csrc/gemm/per_token_group_quant_8bit.cu",
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
     "csrc/moe/moe_align_kernel.cu",
......
 import itertools
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple

 import pytest
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8

-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import is_hip

 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -25,16 +25,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
@@ -48,30 +47,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
@@ -80,12 +76,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min

-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
@@ -98,15 +98,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -115,53 +115,58 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s


-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )

-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

     return x_q, x_s

 @pytest.mark.parametrize(
-    "batch_size, seq_len, group_size",
+    "batch_size, seq_len, group_size, dst_dtype",
     list(
         itertools.product(
             [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
             [64, 128, 256, 512, 1024, 2048],  # seq_len
             [16, 32, 64, 128, 256],  # group_size
+            [torch.int8, fp8_type_],  # dtype
         )
     ),
 )
-def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size):
+def test_per_token_group_quant_compare_implementations(
+    batch_size, seq_len, group_size, dst_dtype
+):
     x = torch.randn(
         (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)

     assert torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
......
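As a reading aid (not part of the commit), the per-token-group quantization that both the Triton kernel and the CUDA kernel above implement can be sketched in plain PyTorch. The helper name and the 2-D example input are hypothetical, and the final cast of the clamped values to int8 or fp8 that the kernels perform is omitted here.

# Reference sketch of the group-quant math (hypothetical helper, assumed shapes):
# per group of `group_size` elements, scale = max(|x|, eps) / qmax and
# q = clamp(x / scale, qmin, qmax), matching y_s and y_q in the kernels above.
import torch

def per_token_group_quant_ref(x, group_size, qmin, qmax, eps=1e-10):
    # View each token's features as groups of `group_size` elements.
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / qmax
    q = (groups / scale).clamp(qmin, qmax)
    return q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
q, s = per_token_group_quant_ref(x, group_size=128, qmin=-128, qmax=127)  # int8 range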