Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5c66c442
Unverified
Commit
5c66c442
authored
Jun 13, 2025
by
fzyzcjy
Committed by
GitHub
Jun 13, 2025
Browse files
Support new DeepGEMM format in per token group quant (#7146)
parent
aa46ed34
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
44 deletions
+92
-44
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+1
-1
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
+83
-40
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+2
-1
sgl-kernel/python/sgl_kernel/gemm.py
sgl-kernel/python/sgl_kernel/gemm.py
+2
-1
sgl-kernel/tests/test_per_token_group_quant_8bit.py
sgl-kernel/tests/test_per_token_group_quant_8bit.py
+4
-1
No files found.
sgl-kernel/csrc/common_extension.cc
View file @
5c66c442
...
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m
.
def
(
m
.
def
(
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float fp8_min, float fp8_max) -> ()"
);
" float eps, float fp8_min, float fp8_max
, bool scale_ue8m0
) -> ()"
);
m
.
impl
(
"sgl_per_token_group_quant_fp8"
,
torch
::
kCUDA
,
&
sgl_per_token_group_quant_fp8
);
m
.
impl
(
"sgl_per_token_group_quant_fp8"
,
torch
::
kCUDA
,
&
sgl_per_token_group_quant_fp8
);
m
.
def
(
m
.
def
(
...
...
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
View file @
5c66c442
...
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
...
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
return
val
;
return
val
;
}
}
template
<
typename
T
,
typename
DST_DTYPE
,
bool
IS_COLUMN_MAJOR
=
false
>
template
<
typename
T
,
typename
DST_DTYPE
,
bool
IS_COLUMN_MAJOR
=
false
,
bool
SCALE_UE8M0
=
false
,
typename
scale_packed_t
=
std
::
conditional_t
<
SCALE_UE8M0
,
uint32_t
,
float
>
>
__global__
void
per_token_group_quant_8bit_kernel
(
__global__
void
per_token_group_quant_8bit_kernel
(
const
T
*
__restrict__
input
,
const
T
*
__restrict__
input
,
void
*
__restrict__
output_q
,
void
*
__restrict__
output_q
,
floa
t
*
__restrict__
output_s
,
scale_packed_
t
*
__restrict__
output_s
,
const
int
group_size
,
const
int
group_size
,
const
int
num_groups
,
const
int
num_groups
,
const
int
groups_per_block
,
const
int
groups_per_block
,
...
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(
...
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(
float
local_absmax
=
eps
;
float
local_absmax
=
eps
;
using
scale_element_t
=
std
::
conditional_t
<
SCALE_UE8M0
,
uint8_t
,
float
>
;
static_assert
(
sizeof
(
scale_packed_t
)
%
sizeof
(
scale_element_t
)
==
0
);
const
T
*
group_input
=
input
+
block_group_offset
;
const
T
*
group_input
=
input
+
block_group_offset
;
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
block_group_offset
;
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
block_group_offset
;
floa
t
*
scale_output
;
scale_element_
t
*
scale_output
;
if
constexpr
(
IS_COLUMN_MAJOR
)
{
if
constexpr
(
IS_COLUMN_MAJOR
)
{
const
int
row_idx
=
global_group_id
/
scale_num_rows
;
const
int
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
scale_packed_t
)
/
sizeof
(
scale_element_t
));
const
int
col_idx
=
global_group_id
%
scale_num_rows
;
const
int
scale_num_rows_element
=
scale_num_rows
*
num_elems_per_pack
;
scale_output
=
output_s
+
(
col_idx
*
scale_stride
+
row_idx
);
const
int
row_idx
=
global_group_id
/
scale_num_rows_element
;
const
int
col_idx_raw
=
global_group_id
%
scale_num_rows_element
;
const
int
col_idx
=
col_idx_raw
/
num_elems_per_pack
;
const
int
pack_idx
=
col_idx_raw
%
num_elems_per_pack
;
scale_output
=
reinterpret_cast
<
scale_element_t
*>
(
output_s
)
+
(
col_idx
*
scale_stride
*
num_elems_per_pack
+
row_idx
*
num_elems_per_pack
+
pack_idx
);
}
else
{
}
else
{
static_assert
(
!
SCALE_UE8M0
);
scale_output
=
output_s
+
global_group_id
;
scale_output
=
output_s
+
global_group_id
;
}
}
...
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(
...
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(
local_absmax
=
GroupReduceMax
(
local_absmax
,
lane_id
);
local_absmax
=
GroupReduceMax
(
local_absmax
,
lane_id
);
const
float
y_s
=
local_absmax
/
max_8bit
;
float
y_s
=
local_absmax
/
max_8bit
;
if
constexpr
(
SCALE_UE8M0
)
{
y_s
=
exp2f
(
ceilf
(
log2f
(
fmaxf
(
fabsf
(
y_s
),
1e-10
f
))));
}
// TODO can optimize
scale_element_t
y_s_quant
;
if
constexpr
(
SCALE_UE8M0
)
{
y_s_quant
=
(
uint8_t
)(((
int
)
log2f
(
y_s
))
+
127
);
}
else
{
y_s_quant
=
y_s
;
}
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
{
*
scale_output
=
y_s
;
*
scale_output
=
y_s
_quant
;
}
}
for
(
int32_t
i
=
lane_id
;
i
<
num_vec_elems
;
i
+=
16
)
{
for
(
int32_t
i
=
lane_id
;
i
<
num_vec_elems
;
i
+=
16
)
{
...
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
...
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
int64_t
group_size
,
int64_t
group_size
,
double
eps
,
double
eps
,
double
min_8bit
,
double
min_8bit
,
double
max_8bit
)
{
double
max_8bit
,
bool
scale_ue8m0
=
false
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
output_q
);
CHECK_INPUT
(
output_q
);
...
@@ -134,7 +160,21 @@ void sgl_per_token_group_quant_8bit(
...
@@ -134,7 +160,21 @@ void sgl_per_token_group_quant_8bit(
dim3 grid(num_blocks); \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
dim3 block(num_threads); \
if (is_column_major) { \
if (is_column_major) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<uint32_t*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
scale_num_rows, \
scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
static_cast<float*>(output_s.data_ptr()), \
...
@@ -146,7 +186,9 @@ void sgl_per_token_group_quant_8bit(
...
@@ -146,7 +186,9 @@ void sgl_per_token_group_quant_8bit(
(float)max_8bit, \
(float)max_8bit, \
scale_num_rows, \
scale_num_rows, \
scale_stride); \
scale_stride); \
} \
} else { \
} else { \
assert(!scale_ue8m0); \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
output_q.data_ptr(), \
...
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
...
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
int64_t
group_size
,
int64_t
group_size
,
double
eps
,
double
eps
,
double
fp8_min
,
double
fp8_min
,
double
fp8_max
)
{
double
fp8_max
,
sgl_per_token_group_quant_8bit
(
input
,
output_q
,
output_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
);
bool
scale_ue8m0
)
{
sgl_per_token_group_quant_8bit
(
input
,
output_q
,
output_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
,
scale_ue8m0
);
}
}
sgl-kernel/include/sgl_kernel_ops.h
View file @
5c66c442
...
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
...
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
int64_t
group_size
,
int64_t
group_size
,
double
eps
,
double
eps
,
double
fp8_min
,
double
fp8_min
,
double
fp8_max
);
double
fp8_max
,
bool
scale_ue8m0
);
void
sgl_per_token_group_quant_int8
(
void
sgl_per_token_group_quant_int8
(
at
::
Tensor
input
,
at
::
Tensor
input
,
at
::
Tensor
output_q
,
at
::
Tensor
output_q
,
...
...
sgl-kernel/python/sgl_kernel/gemm.py
View file @
5c66c442
...
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
...
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
eps
:
float
,
eps
:
float
,
fp8_min
:
float
,
fp8_min
:
float
,
fp8_max
:
float
,
fp8_max
:
float
,
scale_ue8m0
:
bool
,
)
->
None
:
)
->
None
:
torch
.
ops
.
sgl_kernel
.
sgl_per_token_group_quant_fp8
.
default
(
torch
.
ops
.
sgl_kernel
.
sgl_per_token_group_quant_fp8
.
default
(
input
,
output_q
,
output_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
input
,
output_q
,
output_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
,
scale_ue8m0
)
)
...
...
sgl-kernel/tests/test_per_token_group_quant_8bit.py
View file @
5c66c442
...
@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
...
@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
f8_info
=
torch
.
finfo
(
dtype
)
f8_info
=
torch
.
finfo
(
dtype
)
fp8_max
=
f8_info
.
max
fp8_max
=
f8_info
.
max
fp8_min
=
f8_info
.
min
fp8_min
=
f8_info
.
min
sgl_per_token_group_quant_fp8
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
)
scale_ue8m0
=
False
# TODO also test true
sgl_per_token_group_quant_fp8
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
,
scale_ue8m0
)
return
x_q
,
x_s
return
x_q
,
x_s
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment