Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2c8fd993
Unverified
Commit
2c8fd993
authored
Apr 03, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Apr 02, 2025
Browse files
[sgl-kernel] per token group quant support COLUMN MAJOR (#4817)
parent
31da75ab
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
252 additions
and
80 deletions
+252
-80
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
+7
-3
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
+49
-19
sgl-kernel/tests/test_per_token_group_quant_8bit.py
sgl-kernel/tests/test_per_token_group_quant_8bit.py
+196
-58
No files found.
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
View file @
2c8fd993
...
...
@@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(
def
calculate_diff
(
batch_size
,
seq_len
,
group_size
,
dst_dtype
):
device
=
torch
.
device
(
"cuda"
)
hidden_dim
=
group_size
*
2
hidden_dim
=
7168
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_dim
,
device
=
device
,
dtype
=
torch
.
float16
)
x
=
torch
.
randn
(
batch_size
*
seq_len
,
hidden_dim
,
device
=
device
,
dtype
=
torch
.
float16
)
x_q_triton
,
x_s_triton
=
triton_per_token_group_quant_8bit
(
x
.
clone
(),
group_size
,
dst_dtype
...
...
@@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
device
=
torch
.
device
(
"cuda"
)
hidden_dim
=
7168
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_dim
,
device
=
device
,
dtype
=
torch
.
float16
)
x
=
torch
.
randn
(
batch_size
*
seq_len
,
hidden_dim
,
device
=
device
,
dtype
=
torch
.
float16
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
...
...
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
View file @
2c8fd993
...
...
@@ -16,7 +16,7 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
return
val
;
}
template
<
typename
T
,
typename
DST_DTYPE
>
template
<
typename
T
,
typename
DST_DTYPE
,
bool
IS_COLUMN_MAJOR
=
false
>
__global__
void
per_token_group_quant_8bit_kernel
(
const
T
*
__restrict__
input
,
void
*
__restrict__
output_q
,
...
...
@@ -26,19 +26,30 @@ __global__ void per_token_group_quant_8bit_kernel(
const
int
groups_per_block
,
const
float
eps
,
const
float
min_8bit
,
const
float
max_8bit
)
{
const
float
max_8bit
,
const
int
scale_num_rows
=
0
,
const
int
scale_stride
=
0
)
{
const
int
threads_per_group
=
16
;
const
int
local_group_id
=
threadIdx
.
x
/
threads_per_group
;
const
int
lane_id
=
threadIdx
.
x
%
threads_per_group
;
const
int
block_group_id
=
blockIdx
.
x
*
groups_per_block
;
const
int
block_group_offset
=
(
block_group_id
+
local_group_id
)
*
group_size
;
const
int
global_group_id
=
block_group_id
+
local_group_id
;
const
int
block_group_offset
=
global_group_id
*
group_size
;
float
local_absmax
=
eps
;
const
T
*
group_input
=
input
+
block_group_offset
;
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
block_group_offset
;
float
*
scale_output
=
output_s
+
(
block_group_id
+
local_group_id
);
float
*
scale_output
;
if
constexpr
(
IS_COLUMN_MAJOR
)
{
const
int
row_idx
=
global_group_id
/
scale_num_rows
;
const
int
col_idx
=
global_group_id
%
scale_num_rows
;
scale_output
=
output_s
+
(
col_idx
*
scale_stride
+
row_idx
);
}
else
{
scale_output
=
output_s
+
global_group_id
;
}
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
flashinfer
::
vec_t
<
T
,
vec_size
>
;
...
...
@@ -88,11 +99,11 @@ void sgl_per_token_group_quant_8bit(
double
max_8bit
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
output_q
);
CHECK_INPUT
(
output_s
);
const
int
num_groups
=
input
.
numel
()
/
group_size
;
CHECK_EQ
(
input
.
numel
()
%
group_size
,
0
);
CHECK_EQ
(
output_s
.
dim
(),
2
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
@@ -114,20 +125,39 @@ void sgl_per_token_group_quant_8bit(
const
int
num_blocks
=
num_groups
/
groups_per_block
;
const
int
num_threads
=
groups_per_block
*
THREADS_PER_GROUP
;
#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit); \
const
bool
is_column_major
=
output_s
.
stride
(
0
)
<
output_s
.
stride
(
1
);
const
int
scale_num_rows
=
output_s
.
size
(
1
);
const
int
scale_stride
=
output_s
.
stride
(
1
);
// LAUNCH_KERNEL(T, DST_DTYPE): launches per_token_group_quant_8bit_kernel on `stream`
// with a 1-D grid of `num_blocks` blocks and `num_threads` threads per block.
// The runtime flag `is_column_major` picks the IS_COLUMN_MAJOR template instantiation.
// Only the column-major launch passes `scale_num_rows` and `scale_stride`; the
// row-major launch omits them and relies on the kernel's defaulted trailing parameters.
// NOTE(review): the macro itself performs no launch-error check after `<<<...>>>` -
// confirm whether the caller checks for launch failures.
#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
if (is_column_major) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
scale_num_rows, \
scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit); \
} \
} while (0)
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
input
.
scalar_type
(),
scalar_t
,
[
&
]
{
...
...
sgl-kernel/tests/test_per_token_group_quant_8bit.py
View file @
2c8fd993
...
...
@@ -9,12 +9,12 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_
from
sglang.srt.utils
import
is_hip
is_hip
_
=
is_hip
()
fp8_type_
=
torch
.
float8_e4m3fnuz
if
is_hip
_
else
torch
.
float8_e4m3fn
_
is_hip
=
is_hip
()
fp8_type_
=
torch
.
float8_e4m3fnuz
if
_
is_hip
else
torch
.
float8_e4m3fn
@
triton
.
jit
def
_per_token_group_quant_
8bit
(
def
_per_token_group_quant_
fp8
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
...
...
@@ -25,15 +25,16 @@ def _per_token_group_quant_8bit(
N
,
# Avoid to divide zero
eps
,
# Information for
8bit data type (int8 or fp8_type_)
max_8bit
,
min_8bit
,
# Information for
float8
fp8_min
,
fp8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform per-token-group quantization on a
tensor.
This function converts the tensor values into 8bit values.
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
...
...
@@ -47,8 +48,57 @@ def _per_token_group_quant_8bit(
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
max_8bit
y_q
=
tl
.
clamp
(
y
/
y_s
,
min_8bit
,
max_8bit
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
y_s
=
_absmax
/
fp8_max
y_s_inv
=
1.0
/
y_s
y_q
=
tl
.
clamp
(
y
*
y_s_inv
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to the input, the quantized output and the scale output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Number of consecutive elements that share one scale
    group_size,
    # Number of columns of y
    y_num_columns,
    # Stride from one column to the next of y_s
    y_s_col_stride,
    # Lower bound applied to the group absmax so the scale is never zero
    eps,
    # Clamp range applied before the cast to the output element type
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Quantize one group of `group_size` consecutive elements per program.

    Each program loads its group from `y_ptr`, computes
    `scale = max(absmax(group), eps) / fp8_max`, stores
    `clamp(group / scale, fp8_min, fp8_max)` cast to the element type of
    `y_q_ptr`, and writes the scale to
    `y_s_ptr + scale_col * y_s_col_stride + scale_row`.
    """
    # One program per group: program id 0 is the flattened group index.
    group_idx = tl.program_id(0)
    group_start = group_idx * group_size
    src_ptr = y_ptr + group_start
    dst_ptr = y_q_ptr + group_start

    # Turn the flattened group index into (row, column) coordinates of the
    # scale matrix, then address the scale through the column stride.
    groups_per_row = y_num_columns // group_size
    scale_row = group_idx // groups_per_row
    scale_col = group_idx % groups_per_row
    scale_ptr = y_s_ptr + (scale_col * y_s_col_stride + scale_row)

    # BLOCK may exceed group_size, so lanes past the group are masked off.
    lane = tl.arange(0, BLOCK)
    in_group = lane < group_size
    vals = tl.load(src_ptr + lane, mask=in_group, other=0.0).to(tl.float32)

    # Quantize: derive the per-group scale, divide, clamp, cast.
    group_absmax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = group_absmax / fp8_max
    scaled = vals / scale
    quantized = tl.clamp(scaled, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(dst_ptr + lane, quantized, mask=in_group)
    tl.store(scale_ptr, scale)
...
...
@@ -57,17 +107,22 @@ def _per_token_group_quant_8bit(
def
triton_per_token_group_quant_8bit
(
x
:
torch
.
Tensor
,
group_size
:
int
,
dst_dtype
:
torch
.
dtype
,
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
=
fp8_type_
,
column_major_scales
:
bool
=
False
,
scale_tma_aligned
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
dtype: The dtype of the output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
...
...
@@ -76,41 +131,79 @@ def triton_per_token_group_quant_8bit(
),
"the last dimension of `x` cannot be divisible by `group_size`"
assert
x
.
is_contiguous
(),
"`x` is not contiguous"
if
dst_dtype
==
torch
.
int8
:
iinfo
=
torch
.
iinfo
(
dst_dtype
)
max_8bit
=
iinfo
.
max
min_8bit
=
iinfo
.
min
if
dtype
==
torch
.
int8
:
finfo
=
torch
.
iinfo
(
dtype
)
else
:
finfo
=
torch
.
finfo
(
dst_dtype
)
max_8bit
=
finfo
.
max
min_8bit
=
finfo
.
min
finfo
=
torch
.
finfo
(
dtype
)
fp8_max
=
finfo
.
max
if
_is_hip
:
if
dtype
==
torch
.
int8
:
fp8_max
=
127.0
else
:
fp8_max
=
224.0
fp8_min
=
-
fp8_max
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dst_
dtype
)
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
)
if
column_major_scales
:
if
scale_tma_aligned
:
# aligned to 4 * sizeof(float)
aligned_size
=
(
x
.
shape
[
-
2
]
+
3
)
//
4
*
4
x_s
=
torch
.
empty
(
x
.
shape
[:
-
2
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
aligned_size
),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
).
permute
(
-
1
,
-
2
)[:
x
.
shape
[
-
2
],
:]
else
:
x_s
=
torch
.
empty
(
(
x
.
shape
[
-
1
]
//
group_size
,)
+
x
.
shape
[:
-
1
],
device
=
x
.
device
,
dtype
=
torch
.
float32
,
).
permute
(
-
1
,
-
2
)
else
:
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
_per_token_group_quant_8bit
[(
M
,)](
x
,
x_q
,
x_s
,
group_size
,
N
,
eps
,
max_8bit
,
min_8bit
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
if
column_major_scales
:
_per_token_group_quant_fp8_colmajor
[(
M
,)](
x
,
x_q
,
x_s
,
group_size
,
x
.
shape
[
1
],
x_s
.
stride
(
1
),
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
else
:
_per_token_group_quant_fp8
[(
M
,)](
x
,
x_q
,
x_s
,
group_size
,
N
,
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
x_q
,
x_s
...
...
@@ -118,28 +211,48 @@ def triton_per_token_group_quant_8bit(
def
sglang_per_token_group_quant_8bit
(
x
:
torch
.
Tensor
,
group_size
:
int
,
dst_dtype
:
torch
.
dtype
,
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
=
fp8_type_
,
column_major_scales
:
bool
=
False
,
scale_tma_aligned
:
bool
=
False
,
):
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
"the last dimension of `x` cannot be divisible by `group_size`"
assert
x
.
is_contiguous
(),
"`x` is not contiguous"
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dst_dtype
)
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
)
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
if
column_major_scales
:
if
scale_tma_aligned
:
# aligned to 4 * sizeof(float)
aligned_size
=
(
x
.
shape
[
-
2
]
+
3
)
//
4
*
4
x_s
=
torch
.
empty
(
x
.
shape
[:
-
2
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
aligned_size
),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
).
permute
(
-
1
,
-
2
)[:
x
.
shape
[
-
2
],
:]
else
:
x_s
=
torch
.
empty
(
(
x
.
shape
[
-
1
]
//
group_size
,)
+
x
.
shape
[:
-
1
],
device
=
x
.
device
,
dtype
=
torch
.
float32
,
).
permute
(
-
1
,
-
2
)
else
:
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
,
)
if
dst_
dtype
==
torch
.
int8
:
iinfo
=
torch
.
iinfo
(
dst_
dtype
)
if
dtype
==
torch
.
int8
:
iinfo
=
torch
.
iinfo
(
dtype
)
int8_max
=
iinfo
.
max
int8_min
=
iinfo
.
min
sgl_per_token_group_quant_int8
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
int8_min
,
int8_max
)
else
:
f8_info
=
torch
.
finfo
(
dst_
dtype
)
f8_info
=
torch
.
finfo
(
dtype
)
fp8_max
=
f8_info
.
max
fp8_min
=
f8_info
.
min
sgl_per_token_group_quant_fp8
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
)
...
...
@@ -148,30 +261,55 @@ def sglang_per_token_group_quant_8bit(
@
pytest
.
mark
.
parametrize
(
"
batch_size, seq_len
, group_size, dst_dtype"
,
"
num_tokens, hidden_dim
, group_size, dst_dtype
, column_major_scales, scale_tma_aligned
"
,
list
(
itertools
.
product
(
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
],
# batch_size
[
64
,
128
,
256
,
512
,
1024
,
2048
],
# seq_len
[
16
,
32
,
64
,
128
,
256
],
# group_size
[
1
27
,
128
,
512
,
1024
,
4096
,
8192
],
# num_tokens
[
256
,
512
,
1024
,
2048
,
4096
],
# hidden_dim
[
8
,
16
,
32
,
64
,
128
],
# group_size
[
torch
.
int8
,
fp8_type_
],
# dtype
[
False
,
True
],
# column_major_scales
[
False
,
True
],
# scale_tma_aligned
)
),
)
def
test_per_token_group_quant_compare_implementations
(
batch_size
,
seq_len
,
group_size
,
dst_dtype
def
test_per_token_group_quant_with_column_major
(
num_tokens
,
hidden_dim
,
group_size
,
dst_dtype
,
column_major_scales
,
scale_tma_aligned
,
):
x
=
torch
.
randn
(
(
batch_size
,
seq_len
,
group_size
*
2
),
device
=
"cuda"
,
dtype
=
torch
.
float16
if
not
column_major_scales
and
scale_tma_aligned
:
return
x
=
torch
.
randn
(
num_tokens
,
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
float16
)
x_q_triton
,
x_s_triton
=
triton_per_token_group_quant_8bit
(
x
,
group_size
,
eps
=
1e-10
,
dtype
=
dst_dtype
,
column_major_scales
=
column_major_scales
,
scale_tma_aligned
=
scale_tma_aligned
,
)
x_q_triton
,
x_s_triton
=
triton_per_token_group_quant_8bit
(
x
,
group_size
,
dst_dtype
)
x_q_sglang
,
x_s_sglang
=
sglang_per_token_group_quant_8bit
(
x
,
group_size
,
dst_dtype
)
x_q_sglang
,
x_s_sglang
=
sglang_per_token_group_quant_8bit
(
x
,
group_size
,
eps
=
1e-10
,
dtype
=
dst_dtype
,
column_major_scales
=
column_major_scales
,
scale_tma_aligned
=
scale_tma_aligned
,
)
assert
torch
.
allclose
(
x_q_triton
.
to
(
torch
.
float32
),
x_q_sglang
.
to
(
torch
.
float32
),
rtol
=
1e-3
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_s_triton
,
x_s_sglang
,
rtol
=
1e-3
,
atol
=
1e-5
)
assert
torch
.
allclose
(
x_s_triton
.
contiguous
(),
x_s_sglang
.
contiguous
(),
rtol
=
1e-3
,
atol
=
1e-5
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment