Unverified Commit 2c8fd993 authored by Xiaoyu Zhang, committed by GitHub

[sgl-kernel] per token group quant support COLUMN MAJOR (#4817)

parent 31da75ab
@@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(
 def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
     device = torch.device("cuda")
-    hidden_dim = group_size * 2
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+    hidden_dim = 7168
+    x = torch.randn(
+        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+    )

     x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
         x.clone(), group_size, dst_dtype
@@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
     device = torch.device("cuda")
     hidden_dim = 7168
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+    x = torch.randn(
+        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+    )

     quantiles = [0.5, 0.2, 0.8]
...
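For context when reading the kernels below: each group of `group_size` values within a token gets one float32 scale computed from the group's absolute maximum, and the values are then clamped and cast to the 8-bit type. A minimal eager-PyTorch reference of that computation with row-major scales (a sketch for orientation only; `torch_ref_per_token_group_quant` is a hypothetical name, not part of this patch, and the float8 dtype requires a recent PyTorch):

```python
import torch

def torch_ref_per_token_group_quant(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype = torch.float8_e4m3fn,
    eps: float = 1e-10,
):
    # x: (num_tokens, hidden_dim) with hidden_dim divisible by group_size.
    num_tokens, hidden_dim = x.shape
    info = torch.iinfo(dst_dtype) if dst_dtype == torch.int8 else torch.finfo(dst_dtype)
    max_8bit, min_8bit = float(info.max), float(info.min)

    groups = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)
    absmax = groups.abs().amax(dim=-1).clamp_min(eps)   # one absmax per (token, group)
    x_s = absmax / max_8bit                             # row-major scales: (num_tokens, num_groups)
    x_q = (groups / x_s.unsqueeze(-1)).clamp(min_8bit, max_8bit)
    return x_q.reshape(num_tokens, hidden_dim).to(dst_dtype), x_s

# e.g. x_q, x_s = torch_ref_per_token_group_quant(torch.randn(4, 256, dtype=torch.float16), 128)
```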
@@ -16,7 +16,7 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }

-template <typename T, typename DST_DTYPE>
+template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
@@ -26,19 +26,30 @@ __global__ void per_token_group_quant_8bit_kernel(
     const int groups_per_block,
     const float eps,
     const float min_8bit,
-    const float max_8bit) {
+    const float max_8bit,
+    const int scale_num_rows = 0,
+    const int scale_stride = 0) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;

   const int block_group_id = blockIdx.x * groups_per_block;
-  const int block_group_offset = (block_group_id + local_group_id) * group_size;
+  const int global_group_id = block_group_id + local_group_id;
+  const int block_group_offset = global_group_id * group_size;

   float local_absmax = eps;

   const T* group_input = input + block_group_offset;
   DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  float* scale_output = output_s + (block_group_id + local_group_id);
+  float* scale_output;
+
+  if constexpr (IS_COLUMN_MAJOR) {
+    const int row_idx = global_group_id / scale_num_rows;
+    const int col_idx = global_group_id % scale_num_rows;
+    scale_output = output_s + (col_idx * scale_stride + row_idx);
+  } else {
+    scale_output = output_s + global_group_id;
+  }

   constexpr uint32_t vec_size = 16 / sizeof(T);
   using vec_t = flashinfer::vec_t<T, vec_size>;
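The `IS_COLUMN_MAJOR` branch splits the flat group index into a (token, group-within-token) coordinate and writes the scale at offset `col_idx * scale_stride + row_idx`, which is the strided address of element `[token, group]` in the transposed scale tensor. A small host-side PyTorch sketch of that mapping (illustrative only, not part of the patch; the variable names mirror the kernel):

```python
import torch

num_tokens, hidden_dim, group_size = 4, 16, 8
groups_per_token = hidden_dim // group_size      # equals output_s.size(1), i.e. scale_num_rows

# Column-major scale buffer: contiguous (groups_per_token, num_tokens), viewed transposed.
x_s = torch.empty(groups_per_token, num_tokens, dtype=torch.float32).permute(-1, -2)
scale_stride = x_s.stride(1)                     # equals num_tokens for this layout

for global_group_id in range(num_tokens * groups_per_token):
    row_idx = global_group_id // groups_per_token   # token index
    col_idx = global_group_id % groups_per_token    # group index within the token
    kernel_offset = col_idx * scale_stride + row_idx
    # Same element addressed through the strided view x_s[token, group]:
    view_offset = row_idx * x_s.stride(0) + col_idx * x_s.stride(1)
    assert kernel_offset == view_offset
```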
@@ -88,11 +99,11 @@ void sgl_per_token_group_quant_8bit(
     double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
-  CHECK_INPUT(output_s);

   const int num_groups = input.numel() / group_size;
   CHECK_EQ(input.numel() % group_size, 0);
+  CHECK_EQ(output_s.dim(), 2);

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -114,20 +125,39 @@ void sgl_per_token_group_quant_8bit(
   const int num_blocks = num_groups / groups_per_block;
   const int num_threads = groups_per_block * THREADS_PER_GROUP;

-#define LAUNCH_KERNEL(T, DST_DTYPE) \
-  do { \
-    dim3 grid(num_blocks); \
-    dim3 block(num_threads); \
-    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
-        static_cast<T*>(input.data_ptr()), \
-        output_q.data_ptr(), \
-        static_cast<float*>(output_s.data_ptr()), \
-        group_size, \
-        num_groups, \
-        groups_per_block, \
-        (float)eps, \
-        (float)min_8bit, \
-        (float)max_8bit); \
+  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
+  const int scale_num_rows = output_s.size(1);
+  const int scale_stride = output_s.stride(1);
+
+#define LAUNCH_KERNEL(T, DST_DTYPE) \
+  do { \
+    dim3 grid(num_blocks); \
+    dim3 block(num_threads); \
+    if (is_column_major) { \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
+          static_cast<T*>(input.data_ptr()), \
+          output_q.data_ptr(), \
+          static_cast<float*>(output_s.data_ptr()), \
+          group_size, \
+          num_groups, \
+          groups_per_block, \
+          (float)eps, \
+          (float)min_8bit, \
+          (float)max_8bit, \
+          scale_num_rows, \
+          scale_stride); \
+    } else { \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
+          static_cast<T*>(input.data_ptr()), \
+          output_q.data_ptr(), \
+          static_cast<float*>(output_s.data_ptr()), \
+          group_size, \
+          num_groups, \
+          groups_per_block, \
+          (float)eps, \
+          (float)min_8bit, \
+          (float)max_8bit); \
+    } \
   } while (0)

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
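The host launcher chooses the template instantiation purely from the strides of `output_s`: a transposed (column-major) scale tensor has `stride(0) < stride(1)`. A quick standalone PyTorch check of that invariant (a sketch, not from the patch):

```python
import torch

num_tokens, num_groups = 8, 4   # num_groups = hidden_dim // group_size

# Row-major scales: one contiguous row of scales per token.
row_major = torch.empty(num_tokens, num_groups, dtype=torch.float32)
assert row_major.stride(0) > row_major.stride(1)      # strides (num_groups, 1)

# Column-major scales: allocated transposed, then viewed as (num_tokens, num_groups).
col_major = torch.empty(num_groups, num_tokens, dtype=torch.float32).permute(-1, -2)
assert col_major.shape == (num_tokens, num_groups)
assert col_major.stride(0) < col_major.stride(1)      # strides (1, num_tokens)

# This is exactly the test the launcher performs:
# is_column_major = output_s.stride(0) < output_s.stride(1)
```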
...
@@ -9,12 +9,12 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_
 from sglang.srt.utils import is_hip

-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


 @triton.jit
-def _per_token_group_quant_8bit(
+def _per_token_group_quant_fp8(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -25,15 +25,16 @@ def _per_token_group_quant_8bit(
     N,
     # Avoid to divide zero
     eps,
-    # Information for 8bit data type (int8 or fp8_type_)
-    max_8bit,
-    min_8bit,
+    # Information for float8
+    fp8_min,
+    fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into 8bit values.
+
+    This function converts the tensor values into float8 values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
@@ -47,8 +48,57 @@ def _per_token_group_quant_8bit(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / max_8bit
-    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / fp8_max
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id the flattened block coordinate to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
@@ -57,17 +107,22 @@ def _per_token_group_quant_8bit(
 def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
-    dst_dtype: torch.dtype,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.

     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
@@ -76,41 +131,79 @@ def triton_per_token_group_quant_8bit(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
-        max_8bit = iinfo.max
-        min_8bit = iinfo.min
+    if dtype == torch.int8:
+        finfo = torch.iinfo(dtype)
     else:
-        finfo = torch.finfo(dst_dtype)
-        max_8bit = finfo.max
-        min_8bit = finfo.min
+        finfo = torch.finfo(dtype)
+
+    fp8_max = finfo.max
+    if _is_hip:
+        if dtype == torch.int8:
+            fp8_max = 127.0
+        else:
+            fp8_max = 224.0
+    fp8_min = -fp8_max

-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )

     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_8bit[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        max_8bit,
-        min_8bit,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            N,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
     return x_q, x_s
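The `scale_tma_aligned` path pads the token dimension of the scale buffer up to a multiple of 4 float32 elements (16 bytes) before transposing, then slices the view back to the real token count. A standalone PyTorch sketch of the resulting shape and strides (sizes are illustrative, not from the patch):

```python
import torch

num_tokens, hidden_dim, group_size = 127, 512, 128
num_groups = hidden_dim // group_size                  # 4

# Pad the token dimension so each scale column starts on a 16-byte boundary.
aligned_size = (num_tokens + 3) // 4 * 4               # 128 for 127 tokens

x_s = torch.empty(
    num_groups, aligned_size, dtype=torch.float32
).permute(-1, -2)[:num_tokens, :]

print(x_s.shape)            # torch.Size([127, 4])
print(x_s.stride())         # (1, 128): column-major over the padded token dimension
print(x_s.is_contiguous())  # False
```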
@@ -118,28 +211,48 @@ def triton_per_token_group_quant_8bit(
 def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
-    dst_dtype: torch.dtype,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )

-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
+    if dtype == torch.int8:
+        iinfo = torch.iinfo(dtype)
         int8_max = iinfo.max
         int8_min = iinfo.min
         sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
     else:
-        f8_info = torch.finfo(dst_dtype)
+        f8_info = torch.finfo(dtype)
         fp8_max = f8_info.max
         fp8_min = f8_info.min
         sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
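A hedged usage sketch of the updated wrapper with the new flags (assumes a CUDA device, the `sgl_kernel` extension, and that the functions above are importable; the tensor sizes are illustrative, not from the patch):

```python
import torch

x = torch.randn(128, 7168, device="cuda", dtype=torch.float16)

x_q, x_s = sglang_per_token_group_quant_8bit(
    x,
    group_size=128,
    dtype=torch.float8_e4m3fn,       # fp8_type_ resolves to float8_e4m3fnuz on ROCm
    column_major_scales=True,
)

print(x_q.shape, x_q.dtype)   # (128, 7168), torch.float8_e4m3fn
print(x_s.shape)              # (128, 56): one scale per (token, group)
print(x_s.stride())           # (1, 128): column-major, so the IS_COLUMN_MAJOR kernel is used
```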
@@ -148,30 +261,55 @@ def sglang_per_token_group_quant_8bit(
 @pytest.mark.parametrize(
-    "batch_size, seq_len, group_size, dst_dtype",
+    "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
     list(
         itertools.product(
-            [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
-            [64, 128, 256, 512, 1024, 2048],  # seq_len
-            [16, 32, 64, 128, 256],  # group_size
+            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
+            [256, 512, 1024, 2048, 4096],  # hidden_dim
+            [8, 16, 32, 64, 128],  # group_size
             [torch.int8, fp8_type_],  # dtype
+            [False, True],  # column_major_scales
+            [False, True],  # scale_tma_aligned
         )
     ),
 )
-def test_per_token_group_quant_compare_implementations(
-    batch_size, seq_len, group_size, dst_dtype
+def test_per_token_group_quant_with_column_major(
+    num_tokens,
+    hidden_dim,
+    group_size,
+    dst_dtype,
+    column_major_scales,
+    scale_tma_aligned,
 ):
-    x = torch.randn(
-        (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
-    )
-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
+    if not column_major_scales and scale_tma_aligned:
+        return
+
+    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)
+
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
+        x,
+        group_size,
+        eps=1e-10,
+        dtype=dst_dtype,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+    )
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
+        x,
+        group_size,
+        eps=1e-10,
+        dtype=dst_dtype,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+    )

     assert torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)
+    assert torch.allclose(
+        x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
+    )

 if __name__ == "__main__":
...