"docs/source/guide_cn/data-process.rst" did not exist on "9c08cd6bf055eea253cd91d1e3765576004fd9a7"
Commit bb418ced authored by Xiaoyu Zhang, committed by GitHub

optimize per token group quant fp8 (#3490)

parent fdf04a14
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid division by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization
    on a tensor.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the group of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quantize: scale by the group's absolute maximum, then clamp to the fp8 range.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
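
# Illustrative example of the per-group math above (comments only, assuming
# torch.float8_e4m3fn, whose finfo.max is 448.0):
#   group values y = [-2.0, 0.5, 4.0]
#   _absmax        = max(max(|y|), eps) = 4.0
#   y_s            = 4.0 / 448.0 ≈ 0.00893
#   y_q            = clamp(y / y_s, -448, 448) ≈ [-224, 56, 448], cast to fp8
# Dequantizing later as y_q * y_s recovers y approximately.
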
def triton_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid division by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s
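
# Minimal usage sketch (comments only; assumes a CUDA device and uses
# group_size=128, the DeepSeek V3/R1 setting benchmarked below):
#   x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
#   x_q, x_s = triton_per_token_group_quant_fp8(x, group_size=128)
#   # x_q: (4, 512) fp8 tensor; x_s: (4, 4) float32, one scale per 128-wide group.
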
def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
):
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s

def calculate_diff(batch_size, seq_len, group_size):
    dtype = torch.float16
    device = torch.device("cuda")
    hidden_dim = group_size * 2

    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size)

    if torch.allclose(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")

batch_size_range = [1, 2, 4, 8, 16, 32, 64]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
group_size_range = [128] # For DeepSeek V3/R1
configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "group_size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, group_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_dim = group_size * 2

    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "triton":
        fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size)
    elif provider == "sglang":
        fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size)

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    # do_bench reports milliseconds; convert to microseconds to match the "us" ylabel.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms

if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=128, group_size=64)
    benchmark.run(print_data=True)
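
# To reproduce locally (sketch; the script's path/name in the repo may differ):
#   python3 benchmark_per_token_group_quant_fp8.py
# This first prints the correctness check, then a table of latencies in
# microseconds for the "Triton" and "SGL Kernel" providers across the configs.
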
@@ -100,6 +100,7 @@ sources = [
     "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "src/sgl-kernel/csrc/eagle_utils.cu",
     "src/sgl-kernel/csrc/speculative_sampling.cu",
+    "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
@@ -29,6 +29,7 @@ from sgl_kernel.ops import (
     register_graph_buffers,
     rmsnorm,
     sampling_scaling_penalties,
+    sgl_per_token_group_quant_fp8,
     silu_and_mul,
     top_k_renorm_prob,
     top_k_top_p_sampling_from_probs,
@@ -65,4 +66,5 @@ __all__ = [
     "tree_speculative_sampling_target_only",
     "build_tree_kernel_efficient",
     "build_tree_kernel",
+    "sgl_per_token_group_quant_fp8",
 ]
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include "utils.h"
using FP8_TYPE = c10::Float8_e4m3fn;
__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
  if (tid < 8) {
    smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
    if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
    if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
    if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
  }
  return smem[0];
}
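
// Illustrative trace of WarpReduce on one group's 16 partial maxima: threads
// 0-7 first fold in elements 8-15, then the surviving 8 values are folded
// 8 -> 4 -> 2 -> 1, leaving the group-wide absolute maximum in smem[0].
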
template <typename T>
__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q,
                                                 float* __restrict__ output_s, const int group_size,
                                                 const int num_groups, const float eps, const float fp8_min,
                                                 const float fp8_max) {
  const int groups_per_block = 16;
  const int block_group_id = blockIdx.x * groups_per_block;
  const int tid = threadIdx.x;
  const int local_group_id = tid / 16;  // Each 16 threads handle one group
  const int local_tid = tid % 16;       // Thread ID within the group

  __shared__ float s_absmax[16][17];  // Use 17 instead of 16 to avoid bank conflicts

  // Local maximum value for each thread
  float local_absmax = eps;

  // Ensure this block doesn't process out-of-bounds groups
  if (block_group_id + local_group_id < num_groups) {
    // Calculate input/output pointers for current group
    const T* group_input = input + (block_group_id + local_group_id) * group_size;
    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
    float* scale_output = output_s + block_group_id + local_group_id;

    // Calculate local maximum absolute value
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }

    // Store in shared memory
    s_absmax[local_group_id][local_tid] = local_absmax;
    __syncthreads();

    // Perform reduction within each group
    if (local_tid < 8) {
      WarpReduce(&s_absmax[local_group_id][0], local_tid);
    }
    __syncthreads();

    // Get the maximum value for this group
    const float group_absmax = s_absmax[local_group_id][0];
    const float y_s = group_absmax / fp8_max;

    // Only the first thread in each group writes the scale
    if (local_tid == 0) {
      *scale_output = y_s;
    }

    // Quantize the data
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i] = FP8_TYPE(q_val);
    }
  }
}
void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s,
                                    int64_t group_size, double eps, double fp8_min, double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int num_groups = input.numel() / group_size;
  CHECK_EQ(input.numel() % group_size, 0);

  // Each block processes 16 groups, adjust grid size accordingly
  dim3 grid((num_groups + 15) / 16);
  dim3 block(256);  // Keep 256 threads, each 16 threads handle one group

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()), output_q.data_ptr(), static_cast<float*>(output_s.data_ptr()),
        group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max);
    return true;
  });
}
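
// Worked launch example (illustrative values, not from the source): an input
// of shape (4, 128, 256) with group_size = 128 gives
// num_groups = 4 * 128 * 256 / 128 = 1024, so grid = (1024 + 15) / 16 = 64
// blocks of 256 threads, each block quantizing 16 groups of 128 elements.
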
@@ -143,3 +143,7 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
 void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                        at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                        int64_t depth, int64_t draft_token_num);
+
+// sgl_per_token_group_quant_fp8
+void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
+                                   double eps, double fp8_min, double fp8_max);
@@ -579,3 +579,17 @@ def build_tree_kernel(
         depth,
         draft_token_num,
     )
+
+
+def sgl_per_token_group_quant_fp8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    fp8_min: float,
+    fp8_max: float,
+) -> None:
+    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
+    )
@@ -153,6 +153,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
       "int topk, int depth, int draft_token_num) -> ()");
   m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
+
+  // per_token_group_quant_fp8
+  m.def(
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
 }

 REGISTER_EXTENSION(_kernels)
import itertools
from typing import Any, Dict, List, Optional, Tuple
import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid division by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization
    on a tensor.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the group of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quantize: scale by the group's absolute maximum, then clamp to the fp8 range.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)

def triton_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid division by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s

def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
):
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s

@pytest.mark.parametrize(
"batch_size, seq_len, group_size",
list(
itertools.product(
[1, 2, 4, 8, 16], # batch_size
[64, 128, 256, 512, 1024, 2048], # seq_len
[64, 128, 256], # group_size
)
),
)
def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size):
x = torch.randn(
(batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
)
x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size)
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size)
assert torch.allclose(
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
)
assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)
if __name__ == "__main__":
    pytest.main([__file__])