change / sglang · Commits · 65c24c28

Commit 65c24c28 (Unverified), authored Mar 23, 2025 by Chunan Zeng, committed by GitHub on Mar 23, 2025

[Quant Kernel] refactored per token group quant fp8 to support int8 up-to 2x faster (#4396)

Parent: 3980ff1b
Showing 8 changed files with 191 additions and 127 deletions (+191 -127)
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py   +58 -48
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu          +60 -39
sgl-kernel/csrc/torch_extension.cc                          +5 -0
sgl-kernel/include/sgl_kernel_ops.h                         +8 -0
sgl-kernel/python/sgl_kernel/__init__.py                    +1 -0
sgl-kernel/python/sgl_kernel/gemm.py                        +14 -0
sgl-kernel/setup.py                                         +1 -1
sgl-kernel/tests/test_per_token_group_quant_8bit.py         +44 -39
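Editor's note: for orientation, here is a minimal PyTorch sketch (not part of this commit; the function name is ours) of the per-token-group 8-bit quantization that the refactored kernel implements. Each group of `group_size` values gets one float32 scale equal to its absmax divided by the largest representable value of the destination dtype, and values are clamped to that dtype's range. Rounding behavior here is illustrative only and is not bit-exact with the CUDA/Triton kernels.

import torch

def per_token_group_quant_8bit_reference(x, group_size, dst_dtype, eps=1e-10):
    # Range of the destination 8-bit type (int8 or a float8 dtype such as torch.float8_e4m3fn).
    info = torch.iinfo(dst_dtype) if dst_dtype == torch.int8 else torch.finfo(dst_dtype)
    min_8bit, max_8bit = float(info.min), float(info.max)
    groups = x.reshape(-1, group_size).to(torch.float32)
    # One scale per group: absmax (floored at eps) divided by the largest representable value.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / max_8bit
    # Scale, clamp to the 8-bit range, then cast to the destination dtype.
    q = (groups / scale).clamp(min_8bit, max_8bit).to(dst_dtype)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], x.shape[-1] // group_size)

# Example: x_q, x_s = per_token_group_quant_8bit_reference(torch.randn(4, 128, 256), 128, torch.int8)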
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py → sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py

@@ -4,7 +4,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
 from sglang.srt.utils import is_hip
@@ -13,7 +13,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -24,16 +24,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
@@ -47,30 +46,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)

-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
@@ -79,12 +75,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
@@ -97,15 +97,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -114,50 +114,55 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s

-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
     return x_q, x_s

-def calculate_diff(batch_size, seq_len, group_size):
-    dtype = torch.float16
+def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
     device = torch.device("cuda")
     hidden_dim = group_size * 2
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     if torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
-        print("✅ All implementations match")
+        print(f"✅ {dst_dtype} implementations match")
     else:
         print("❌ Implementations differ")
@@ -165,36 +170,40 @@ def calculate_diff(batch_size, seq_len, group_size):
 batch_size_range = [1, 2, 4, 8, 16, 32, 64]
 seq_len_range = [64, 128, 256, 512, 1024, 2048]
 group_size_range = [128]  # For DeepSeek V3/R1
+dst_dtype_range = [torch.int8, fp8_type_]
-configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range))
+configs = list(
+    itertools.product(batch_size_range, seq_len_range, group_size_range, dst_dtype_range)
+)

 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len", "group_size"],
+        x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
         line_names=["Triton", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
-        plot_name="per-token-group-quant-fp8-performance",
+        plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(batch_size, seq_len, group_size, provider):
-    dtype = torch.bfloat16
+def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
     device = torch.device("cuda")
     hidden_dim = 7168
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
     if provider == "triton":
-        fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     elif provider == "sglang":
-        fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
@@ -203,6 +212,7 @@ def benchmark(batch_size, seq_len, group_size, provider):
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=128, group_size=64)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
     benchmark.run(print_data=True)
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu → sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

@@ -6,8 +6,6 @@
 #include "utils.h"
-using FP8_TYPE = c10::Float8_e4m3fn;
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
@@ -18,27 +16,28 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }
-template <typename T, int GROUPS_PER_BLOCK = 16>
-__global__ void per_token_group_quant_fp8_kernel(
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
     float* __restrict__ output_s,
     const int group_size,
     const int num_groups,
+    const int groups_per_block,
     const float eps,
-    const float fp8_min,
-    const float fp8_max) {
+    const float min_8bit,
+    const float max_8bit) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;
-  const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK;
+  const int block_group_id = blockIdx.x * groups_per_block;
   const int block_group_offset = (block_group_id + local_group_id) * group_size;
   float local_absmax = eps;
   const T* group_input = input + block_group_offset;
-  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
+  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
   float* scale_output = output_s + (block_group_id + local_group_id);
   constexpr uint32_t vec_size = 16 / sizeof(T);
@@ -60,7 +59,7 @@ __global__ void per_token_group_quant_fp8_kernel(
   local_absmax = GroupReduceMax(local_absmax, lane_id);
-  const float y_s = local_absmax / fp8_max;
+  const float y_s = local_absmax / max_8bit;
   if (lane_id == 0) {
     *scale_output = y_s;
@@ -73,20 +72,20 @@ __global__ void per_token_group_quant_fp8_kernel(
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
       float val = static_cast<float>(input_vec[j]);
-      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
-      group_output[i * vec_size + j] = FP8_TYPE(q_val);
+      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
+      group_output[i * vec_size + j] = DST_DTYPE(q_val);
     }
   }
 }
-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max) {
+    double min_8bit,
+    double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
@@ -111,36 +110,58 @@ void sgl_per_token_group_quant_fp8(
     groups_per_block = 2;
   }
-#define LAUNCH_KERNEL(T, GPB)                                                          \
-  do {                                                                                 \
-    constexpr int GROUPS_PER_BLOCK = GPB;                                              \
-    dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);                 \
-    dim3 block(GROUPS_PER_BLOCK * THREADS_PER_GROUP);                                  \
-    per_token_group_quant_fp8_kernel<T, GROUPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
-        static_cast<T*>(input.data_ptr()),                                             \
-        output_q.data_ptr(),                                                           \
-        static_cast<float*>(output_s.data_ptr()),                                      \
-        group_size,                                                                    \
-        num_groups,                                                                    \
-        (float)eps,                                                                    \
-        (float)fp8_min,                                                                \
-        (float)fp8_max);                                                               \
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                               \
+  do {                                                                            \
+    dim3 grid(num_blocks);                                                        \
+    dim3 block(num_threads);                                                      \
+    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>(  \
+        static_cast<T*>(input.data_ptr()),                                        \
+        output_q.data_ptr(),                                                      \
+        static_cast<float*>(output_s.data_ptr()),                                 \
+        group_size,                                                               \
+        num_groups,                                                               \
+        groups_per_block,                                                         \
+        (float)eps,                                                               \
+        (float)min_8bit,                                                          \
+        (float)max_8bit);                                                         \
   } while (0)
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (groups_per_block == 16) {
-      LAUNCH_KERNEL(scalar_t, 16);
-    } else if (groups_per_block == 8) {
-      LAUNCH_KERNEL(scalar_t, 8);
-    } else if (groups_per_block == 4) {
-      LAUNCH_KERNEL(scalar_t, 4);
-    } else if (groups_per_block == 2) {
-      LAUNCH_KERNEL(scalar_t, 2);
-    } else {
-      LAUNCH_KERNEL(scalar_t, 1);
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+      return true;
     }
-    return true;
+    return false;
   });
 #undef LAUNCH_KERNEL
 }
+
+void sgl_per_token_group_quant_int8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
+}
+
+void sgl_per_token_group_quant_fp8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double fp8_min,
+    double fp8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+}
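Editor's note: one effect of the hunk above is that `groups_per_block` becomes a runtime kernel argument instead of the `GROUPS_PER_BLOCK` template parameter, so the host side only needs one template instantiation per (input dtype, destination dtype) pair and the launch shape is plain arithmetic. A small illustrative sketch of that arithmetic, using values visible in the diff (16 threads per group, `num_blocks = num_groups / groups_per_block`); the helper name is ours, not part of the commit.

THREADS_PER_GROUP = 16  # matches threads_per_group in per_token_group_quant_8bit_kernel

def launch_shape(num_groups: int, groups_per_block: int):
    # grid and block dimensions as computed before LAUNCH_KERNEL in the diff
    num_blocks = num_groups // groups_per_block
    num_threads = groups_per_block * THREADS_PER_GROUP
    return num_blocks, num_threads

# e.g. 4096 groups packed 16 per block -> 256 blocks of 256 threads
assert launch_shape(4096, 16) == (256, 256)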
sgl-kernel/csrc/torch_extension.cc

@@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
sgl-kernel/include/sgl_kernel_ops.h

@@ -141,6 +141,14 @@ void sgl_per_token_group_quant_fp8(
     double eps,
     double fp8_min,
     double fp8_max);
+void sgl_per_token_group_quant_int8(
+    at::Tensor input,
+    at::Tensor output_q,
+    at::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void cublas_grouped_gemm(
sgl-kernel/python/sgl_kernel/__init__.py

@@ -31,6 +31,7 @@ from sgl_kernel.gemm import (
     int8_scaled_mm,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
+    sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
 from sgl_kernel.moe import moe_align_block_size, topk_softmax
sgl-kernel/python/sgl_kernel/gemm.py

@@ -96,6 +96,20 @@ def sgl_per_token_group_quant_fp8(
     )

+def sgl_per_token_group_quant_int8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    int8_min: float,
+    int8_max: float,
+) -> None:
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+        input, output_q, output_s, group_size, eps, int8_min, int8_max
+    )
+
+
 def sgl_per_tensor_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
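Editor's note: a minimal usage sketch of the new wrapper, assuming a CUDA device; the tensor layout and bounds follow the `sglang_per_token_group_quant_8bit` helper in the benchmark above (the caller allocates the quantized output and the per-group float32 scales).

import torch
from sgl_kernel import sgl_per_token_group_quant_int8

group_size = 128
x = torch.randn(4, 128, 7168, device="cuda", dtype=torch.float16)

x_q = torch.empty_like(x, dtype=torch.int8)  # quantized values, same shape as x
x_s = torch.empty(*x.shape[:-1], x.shape[-1] // group_size,  # one fp32 scale per group
                  device=x.device, dtype=torch.float32)

info = torch.iinfo(torch.int8)
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, 1e-10, info.min, info.max)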
sgl-kernel/setup.py

@@ -153,7 +153,7 @@ sources = [
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
-    "csrc/gemm/per_token_group_quant_fp8.cu",
+    "csrc/gemm/per_token_group_quant_8bit.cu",
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
     "csrc/moe/moe_align_kernel.cu",
sgl-kernel/tests/test_per_token_group_quant_fp8.py → sgl-kernel/tests/test_per_token_group_quant_8bit.py

 import itertools
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple

 import pytest
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
+from sglang.srt.utils import is_hip

 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn

 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -25,16 +25,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
@@ -48,30 +47,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)

-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
@@ -80,12 +76,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
@@ -98,15 +98,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -115,53 +115,58 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s

-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
     return x_q, x_s

 @pytest.mark.parametrize(
-    "batch_size, seq_len, group_size",
+    "batch_size, seq_len, group_size, dst_dtype",
     list(
         itertools.product(
             [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
             [64, 128, 256, 512, 1024, 2048],  # seq_len
             [16, 32, 64, 128, 256],  # group_size
+            [torch.int8, fp8_type_],  # dtype
         )
     ),
 )
-def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size):
+def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size, dst_dtype):
     x = torch.randn((batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16)
-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
     assert torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5