sglang, commit 65c24c28 (unverified)
Authored Mar 23, 2025 by Chunan Zeng; committed by GitHub on Mar 23, 2025
[Quant Kernel] refactored per token group quant fp8 to support int8 up-to 2x faster (#4396)
Parent: 3980ff1b
Showing 8 changed files with 191 additions and 127 deletions
Files changed:
  sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py   +58 -48
  sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu          +60 -39
  sgl-kernel/csrc/torch_extension.cc                          +5 -0
  sgl-kernel/include/sgl_kernel_ops.h                         +8 -0
  sgl-kernel/python/sgl_kernel/__init__.py                    +1 -0
  sgl-kernel/python/sgl_kernel/gemm.py                        +14 -0
  sgl-kernel/setup.py                                         +1 -1
  sgl-kernel/tests/test_per_token_group_quant_8bit.py         +44 -39
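
For orientation, the math is unchanged by this refactor: each group of group_size values in a token row is scaled by its own absmax and clamped into the destination type's representable range; only the bounds (int8 vs. fp8) now vary with dst_dtype. A minimal pure-PyTorch sketch of that computation, mirroring the Triton reference in the diffs below (illustrative only, not part of this commit; the function name here is made up):

import torch

def per_token_group_quant_8bit_ref(x, group_size, dst_dtype, eps=1e-10):
    # Destination range: 127/-128 for int8, +/-448 for float8_e4m3fn.
    info = torch.iinfo(dst_dtype) if dst_dtype == torch.int8 else torch.finfo(dst_dtype)
    max_8bit, min_8bit = float(info.max), float(info.min)
    g = x.float().reshape(-1, group_size)                  # one row per (token, group)
    absmax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / max_8bit                              # y_s in the kernel
    q = (g / scale).clamp(min_8bit, max_8bit).to(dst_dtype)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)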
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py → sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
@@ -4,7 +4,7 @@ from typing import Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
 from sglang.srt.utils import is_hip
...
@@ -13,7 +13,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
...
@@ -24,16 +24,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
...
@@ -47,30 +46,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)

-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
...
@@ -79,12 +75,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
...
@@ -97,15 +97,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
...
@@ -114,50 +114,55 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s

-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
     return x_q, x_s

-def calculate_diff(batch_size, seq_len, group_size):
-    dtype = torch.float16
+def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
     device = torch.device("cuda")
     hidden_dim = group_size * 2
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
+        x.clone(), group_size, dst_dtype
+    )
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
+        x.clone(), group_size, dst_dtype
+    )
     if torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
-        print("✅ All implementations match")
+        print(f"✅ {dst_dtype} implementations match")
     else:
         print("❌ Implementations differ")
...
@@ -165,36 +170,40 @@ def calculate_diff(batch_size, seq_len, group_size):
 batch_size_range = [1, 2, 4, 8, 16, 32, 64]
 seq_len_range = [64, 128, 256, 512, 1024, 2048]
 group_size_range = [128]  # For DeepSeek V3/R1
+dst_dtype_range = [torch.int8, fp8_type_]

-configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range))
+configs = list(
+    itertools.product(batch_size_range, seq_len_range, group_size_range, dst_dtype_range)
+)

 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len", "group_size"],
+        x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
         line_names=["Triton", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
-        plot_name="per-token-group-quant-fp8-performance",
+        plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(batch_size, seq_len, group_size, provider):
-    dtype = torch.bfloat16
+def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
     device = torch.device("cuda")
     hidden_dim = 7168
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
     if provider == "triton":
-        fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     elif provider == "sglang":
-        fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size)
+        fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
     ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
...
@@ -203,6 +212,7 @@ def benchmark(batch_size, seq_len, group_size, provider):
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=128, group_size=64)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
+    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
     benchmark.run(print_data=True)
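
calculate_diff and benchmark above now sweep dst_dtype in addition to the shape parameters, so the Triton reference and the SGL kernel are compared for both int8 and fp8 outputs. A usage sketch of the helper defined in this file, assuming the definitions above are in scope and a CUDA device is available (illustrative, not part of the commit):

import torch

x = torch.randn(4, 128, 256, device="cuda", dtype=torch.float16)

# Same helper, two destination dtypes; the returned scales are always float32.
x_q_int8, s_int8 = sglang_per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.int8)
x_q_fp8, s_fp8 = sglang_per_token_group_quant_8bit(x, group_size=128, dst_dtype=fp8_type_)

assert x_q_int8.dtype == torch.int8 and s_int8.dtype == torch.float32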
sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu → sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

@@ -6,8 +6,6 @@
 #include "utils.h"

-using FP8_TYPE = c10::Float8_e4m3fn;
-
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
...
@@ -18,27 +16,28 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }

-template <typename T, int GROUPS_PER_BLOCK = 16>
-__global__ void per_token_group_quant_fp8_kernel(
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
     float* __restrict__ output_s,
     const int group_size,
     const int num_groups,
+    const int groups_per_block,
     const float eps,
-    const float fp8_min,
-    const float fp8_max) {
+    const float min_8bit,
+    const float max_8bit) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;
-  const int block_group_id = blockIdx.x * GROUPS_PER_BLOCK;
+  const int block_group_id = blockIdx.x * groups_per_block;
   const int block_group_offset = (block_group_id + local_group_id) * group_size;

   float local_absmax = eps;

   const T* group_input = input + block_group_offset;
-  FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
+  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
   float* scale_output = output_s + (block_group_id + local_group_id);

   constexpr uint32_t vec_size = 16 / sizeof(T);
...
@@ -60,7 +59,7 @@ __global__ void per_token_group_quant_fp8_kernel(
   local_absmax = GroupReduceMax(local_absmax, lane_id);

-  const float y_s = local_absmax / fp8_max;
+  const float y_s = local_absmax / max_8bit;

   if (lane_id == 0) {
     *scale_output = y_s;
...
@@ -73,20 +72,20 @@ __global__ void per_token_group_quant_fp8_kernel(
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
       float val = static_cast<float>(input_vec[j]);
-      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
-      group_output[i * vec_size + j] = FP8_TYPE(q_val);
+      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
+      group_output[i * vec_size + j] = DST_DTYPE(q_val);
     }
   }
 }

-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max) {
+    double min_8bit,
+    double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
...
@@ -111,36 +110,58 @@ void sgl_per_token_group_quant_fp8(
     groups_per_block = 2;
   }

-#define LAUNCH_KERNEL(T, GPB)                                                           \
-  do {                                                                                  \
-    constexpr int GROUPS_PER_BLOCK = GPB;                                               \
-    dim3 grid((num_groups + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);                  \
-    dim3 block(GROUPS_PER_BLOCK* THREADS_PER_GROUP);                                    \
-    per_token_group_quant_fp8_kernel<T, GROUPS_PER_BLOCK><<<grid, block, 0, stream>>>(  \
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                                     \
+  do {                                                                                  \
+    dim3 grid(num_blocks);                                                              \
+    dim3 block(num_threads);                                                            \
+    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>(        \
         static_cast<T*>(input.data_ptr()),                                              \
         output_q.data_ptr(),                                                            \
         static_cast<float*>(output_s.data_ptr()),                                       \
         group_size,                                                                     \
         num_groups,                                                                     \
+        groups_per_block,                                                               \
         (float)eps,                                                                     \
-        (float)fp8_min,                                                                 \
-        (float)fp8_max);                                                                \
+        (float)min_8bit,                                                                \
+        (float)max_8bit);                                                               \
   } while (0)

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (groups_per_block == 16) {
-      LAUNCH_KERNEL(scalar_t, 16);
-    } else if (groups_per_block == 8) {
-      LAUNCH_KERNEL(scalar_t, 8);
-    } else if (groups_per_block == 4) {
-      LAUNCH_KERNEL(scalar_t, 4);
-    } else if (groups_per_block == 2) {
-      LAUNCH_KERNEL(scalar_t, 2);
-    } else {
-      LAUNCH_KERNEL(scalar_t, 1);
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+      return true;
     }
-    return true;
+    return false;
   });

 #undef LAUNCH_KERNEL
 }
+
+void sgl_per_token_group_quant_int8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
+}
+
+void sgl_per_token_group_quant_fp8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double fp8_min,
+    double fp8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+}
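
Host-side dispatch now keys on the output tensor's scalar type (at::ScalarType::Char vs. at::ScalarType::Float8_e4m3fn) to pick the DST_DTYPE instantiation, and the clamp range the Python wrappers pass down differs per target. A quick sanity check of those bounds with stock PyTorch (the numbers come from torch itself, not from this commit):

import torch

# int8 range and float8 e4m3fn range used as min_8bit/max_8bit by the wrappers.
print(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)                      # -128 127
print(torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max)    # -448.0 448.0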
sgl-kernel/csrc/torch_extension.cc

@@ -98,6 +98,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
sgl-kernel/include/sgl_kernel_ops.h

@@ -141,6 +141,14 @@ void sgl_per_token_group_quant_fp8(
     double eps,
     double fp8_min,
     double fp8_max);
+void sgl_per_token_group_quant_int8(
+    at::Tensor input,
+    at::Tensor output_q,
+    at::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void cublas_grouped_gemm(
...
sgl-kernel/python/sgl_kernel/__init__.py

@@ -31,6 +31,7 @@ from sgl_kernel.gemm import (
     int8_scaled_mm,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
+    sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
 from sgl_kernel.moe import moe_align_block_size, topk_softmax
...
sgl-kernel/python/sgl_kernel/gemm.py

@@ -96,6 +96,20 @@ def sgl_per_token_group_quant_fp8(
     )

+def sgl_per_token_group_quant_int8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    int8_min: float,
+    int8_max: float,
+) -> None:
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+        input, output_q, output_s, group_size, eps, int8_min, int8_max
+    )
+
 def sgl_per_tensor_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
...
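
Like the existing fp8 binding, the new sgl_per_token_group_quant_int8 wrapper writes into caller-allocated output buffers. A hypothetical direct call, assuming a CUDA device and an sgl_kernel build that contains this commit (buffer shapes follow the wrappers in the benchmark and test files shown on this page):

import torch
from sgl_kernel import sgl_per_token_group_quant_int8

group_size = 128
x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)

x_q = torch.empty_like(x, dtype=torch.int8)                                     # quantized output
x_s = torch.empty(16, 1024 // group_size, device="cuda", dtype=torch.float32)   # per-group scales

iinfo = torch.iinfo(torch.int8)
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, 1e-10, iinfo.min, iinfo.max)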
sgl-kernel/setup.py

@@ -153,7 +153,7 @@ sources = [
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
-    "csrc/gemm/per_token_group_quant_fp8.cu",
+    "csrc/gemm/per_token_group_quant_8bit.cu",
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
     "csrc/moe/moe_align_kernel.cu",
...
sgl-kernel/tests/test_per_token_group_quant_fp8.py → sgl-kernel/tests/test_per_token_group_quant_8bit.py

 import itertools
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple

 import pytest
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8
+from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import is_hip

 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn

 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
...
@@ -25,16 +25,15 @@ def _per_token_group_quant_fp8(
     N,
     # Avoid to divide zero
     eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
+    # Information for 8bit data type (int8 or fp8_type_)
+    max_8bit,
+    min_8bit,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into float8 values.
+    This function converts the tensor values into 8bit values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
...
@@ -48,30 +47,27 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / max_8bit
+    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)

-def triton_per_token_group_quant_fp8(
+def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
...
@@ -80,12 +76,16 @@ def triton_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        max_8bit = iinfo.max
+        min_8bit = iinfo.min
+    else:
+        finfo = torch.finfo(dst_dtype)
+        max_8bit = finfo.max
+        min_8bit = finfo.min
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     M = x.numel() // group_size
     N = group_size
     x_s = torch.empty(
...
@@ -98,15 +98,15 @@ def triton_per_token_group_quant_fp8(
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
+    _per_token_group_quant_8bit[(M,)](
         x,
         x_q,
         x_s,
         group_size,
         N,
         eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
+        max_8bit,
+        min_8bit,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
...
@@ -115,53 +115,58 @@ def triton_per_token_group_quant_fp8(
     return x_q, x_s

-def sglang_per_token_group_quant_fp8(
+def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
+    dst_dtype: torch.dtype,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
+    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size,),
         device=x.device,
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    if dst_dtype == torch.int8:
+        iinfo = torch.iinfo(dst_dtype)
+        int8_max = iinfo.max
+        int8_min = iinfo.min
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    else:
+        f8_info = torch.finfo(dst_dtype)
+        fp8_max = f8_info.max
+        fp8_min = f8_info.min
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
     return x_q, x_s

 @pytest.mark.parametrize(
-    "batch_size, seq_len, group_size",
+    "batch_size, seq_len, group_size, dst_dtype",
     list(
         itertools.product(
             [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
             [64, 128, 256, 512, 1024, 2048],  # seq_len
             [16, 32, 64, 128, 256],  # group_size
+            [torch.int8, fp8_type_],  # dtype
         )
     ),
 )
-def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size):
+def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size, dst_dtype):
     x = torch.randn(
         (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)

     assert torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
...
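
To exercise both destination dtypes locally, the renamed test module can be run with pytest, assuming a CUDA device and builds of sgl_kernel and sglang that include this change:

python -m pytest sgl-kernel/tests/test_per_token_group_quant_8bit.py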