change / sglang / Commits / 339f8eef

Commit 339f8eef (unverified), authored Sep 05, 2025 by fzyzcjy, committed by GitHub on Sep 05, 2025

[1/2] Optimizations and refactors about quant kernel (#9534)

parent afd9f2f5

Showing 11 changed files with 1002 additions and 335 deletions (+1002 -335)
python/sglang/srt/bench_utils.py                              +4    -2
python/sglang/srt/layers/quantization/fp8_kernel.py           +29   -8
python/sglang/srt/layers/quantization/int8_kernel.py          +8    -2
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py      +203  -48
sgl-kernel/csrc/common_extension.cc                           +3    -8
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu            +447  -177
sgl-kernel/include/sgl_kernel_ops.h                           +6    -12
sgl-kernel/python/sgl_kernel/__init__.py                      +1    -2
sgl-kernel/python/sgl_kernel/gemm.py                          +15   -18
sgl-kernel/python/sgl_kernel/test_utils.py                    +125  -0
sgl-kernel/tests/test_per_token_group_quant_8bit.py           +161  -58
python/sglang/srt/bench_utils.py

 import os
+import re
 import sys
 from contextlib import nullcontext
...
@@ -108,7 +109,8 @@ def bench_kineto(
         if not with_multiple_kernels:
             for name in kernel_names:
                 assert (
-                    sum([name in line for line in prof_lines]) == 1
+                    sum([int(re.search(name, line) is not None) for line in prof_lines]) == 1
                 ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
...
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
...
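The bench_utils.py change above swaps plain substring matching for re.search, so a single kernel-name entry can be a regex alternation covering several kernels (the benchmark below passes "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"). A minimal sketch of the difference, using hypothetical profiler rows rather than real bench_kineto output:

import re

# Hypothetical rows standing in for prof_lines from the profiler table.
prof_lines = [
    "  _per_token_group_quant_8bit          12.3us   30",
    "  _silu_and_mul_post_quant_kernel       4.5us   30",
]
name = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"

# Old check: the alternation string is never a literal substring of a row.
assert sum([name in line for line in prof_lines]) == 0
# New check: re.search treats "|" as alternation, so both rows match.
assert sum(int(re.search(name, line) is not None) for line in prof_lines) == 2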
python/sglang/srt/layers/quantization/fp8_kernel.py

...
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
...
@@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0)
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+            )
+        else:
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )

     return x_q, x_s
...
python/sglang/srt/layers/quantization/int8_kernel.py

...
@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
...
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)

     return x_q, x_s
...
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py

 import itertools
+import os
 import time
 from functools import partial
 from pathlib import Path

 import torch
 import triton
+from sgl_kernel.test_utils import create_per_token_group_quant_test_data
 from sglang.srt.bench_utils import bench_kineto
 from sglang.srt.layers.quantization.fp8_kernel import (
...
@@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

+mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
+
-num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
-hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-group_size_range = [128]  # For DeepSeek V3/R1
-# TODO test int8
-dst_dtype_range = [fp8_type_]
-flags_range = [
-    dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-]
-configs = list(
-    itertools.product(
-        num_tokens_range,
-        hidden_dim_range,
-        group_size_range,
-        dst_dtype_range,
-        flags_range,
-    )
-)
+if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
+    # configs = [[
+    #     768,
+    #     16384,
+    #     128,
+    #     None,
+    #     fp8_type_,
+    #     dict(
+    #         column_major_scales=True,
+    #         scale_tma_aligned=True,
+    #         scale_ue8m0=True,
+    #         fuse_silu_and_mul=False,
+    #         masked_layout_mode=None,
+    #     ),
+    # ]]
+    configs = [
+        [
+            768 * 8,
+            2048,
+            128,
+            48,
+            fp8_type_,
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                # masked_layout_mode=None,
+                masked_layout_mode="balanced",
+                # masked_layout_mode="extreme",
+            ),
+        ]
+    ]
+elif mode_concentrated:
+    configs = list(
+        itertools.product(
+            [768],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [768 * 8],
+            [2048],
+            [128],
+            [48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )
+else:
+    configs = list(
+        itertools.product(
+            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False,
+                     fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False,
+                     fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False,
+                     fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+            [2048],
+            [128],
+            [8, 16, 32, 48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                     fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
+        x_names=[
+            "num_tokens",
+            "hidden_dim",
+            "group_size",
+            "num_ranks",
+            "dst_dtype",
+            "flags",
+        ],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
-        line_names=["Triton", "SGL Kernel"],
+        # Triton has multi kernels and we only report the time for the core one
+        line_names=["Triton (Inaccurate)", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
         plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
-    if flags["scale_ue8m0"] and group_size != 128:
-        return
-
-    device = torch.device("cuda")
-
-    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
+def benchmark(num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider):
+    print(f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}")
+
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )

     fn, kernel_names = {
-        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
+        "triton": (
+            triton_per_token_group_quant_8bit,
+            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
+        ),
         "sglang": (
             sglang_per_token_group_quant_8bit,
             "per_token_group_quant_8bit_kernel",
         ),
     }[provider]

-    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
-    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
+    bench_fn = lambda: fn(
+        x=x,
+        masked_m=masked_m,
+        group_size=group_size,
+        dst_dtype=dst_dtype,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
+    )
+    time_s = bench_kineto(bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30)

     return time_s * 1e6
...
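For reference, the config tuples above are plain itertools.product expansions; the sketch below (with a placeholder string standing in for the torch fp8 dtype, so it runs without a GPU) shows the (num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags) rows that triton.testing.Benchmark receives through x_vals.

import itertools

flags = dict(
    column_major_scales=True,
    scale_tma_aligned=True,
    scale_ue8m0=True,
    fuse_silu_and_mul=True,
    masked_layout_mode="balanced",
)
# "fp8" stands in for fp8_type_ here so the sketch has no torch dependency.
configs = list(itertools.product([768 * 8], [2048], [128], [48], ["fp8"], [flags]))
assert configs == [(768 * 8, 2048, 128, 48, "fp8", flags)]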
sgl-kernel/csrc/common_extension.cc

...
@@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
+  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

 #include <ATen/cuda/CUDAContext.h>
-#include <cuda_fp8.h>
+#include <c10/util/Float8_e4m3fn.h>
 #include <cmath>

 #include <flashinfer/vec_dtypes.cuh>

 #include "utils.h"

+template <int THREADS_PER_SUBWARP>
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;

-  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
-  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
-  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
-  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
+  static_assert(
+      (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
+      "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
+  if constexpr (THREADS_PER_SUBWARP >= 16) {
+    val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
+  }
+  if constexpr (THREADS_PER_SUBWARP >= 8) {
+    val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
+  }
+  if constexpr (THREADS_PER_SUBWARP >= 4) {
+    val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
+  }
+  if constexpr (THREADS_PER_SUBWARP >= 2) {
+    val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
+  }
   return val;
 }

-template <
-    typename T,
-    typename DST_DTYPE,
-    bool IS_COLUMN_MAJOR = false,
-    bool SCALE_UE8M0 = false,
-    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
-__global__ void per_token_group_quant_8bit_kernel(
-    const T* __restrict__ input,
-    void* __restrict__ output_q,
-    scale_packed_t* __restrict__ output_s,
-    const int group_size,
-    const int num_groups,
-    const int groups_per_block,
-    const float eps,
-    const float min_8bit,
-    const float max_8bit,
-    const int num_groups_per_row = 0,
-    const int scale_stride = 0) {
-  const int threads_per_group = 16;
-  const int64_t local_group_id = threadIdx.x / threads_per_group;
-  const int lane_id = threadIdx.x % threads_per_group;
-  const int64_t block_group_id = blockIdx.x * groups_per_block;
-  const int64_t global_group_id = block_group_id + local_group_id;
-  const int64_t block_group_offset = global_group_id * group_size;
-
-  float local_absmax = eps;
-
-  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
-  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
-
-  const T* group_input = input + block_group_offset;
-  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  scale_element_t* scale_output;
-
-  if constexpr (IS_COLUMN_MAJOR) {
-    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
-    const int row_idx = global_group_id / num_groups_per_row;
-    const int col_idx_unpacked = global_group_id % num_groups_per_row;
-    const int col_idx = col_idx_unpacked / num_elems_per_pack;
-    const int pack_idx = col_idx_unpacked % num_elems_per_pack;
-    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
-        (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
-  } else {
-    static_assert(!SCALE_UE8M0);
-    scale_output = output_s + global_group_id;
-  }
-
-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  using vec_t = flashinfer::vec_t<T, vec_size>;
-  const int32_t num_vec_elems = group_size / vec_size;
-
-  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
-    vec_t input_vec;
-    input_vec.cast_load(group_input + i * vec_size);
-
-#pragma unroll
-    for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = static_cast<float>(input_vec[j]);
-      float abs_val = fabsf(val);
-      local_absmax = fmaxf(local_absmax, abs_val);
-    }
-  }
-
-  local_absmax = GroupReduceMax(local_absmax, lane_id);
-
-  float y_s = local_absmax / max_8bit;
-  if constexpr (SCALE_UE8M0) {
-    y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
-  }
-
-  // TODO can optimize
-  scale_element_t y_s_quant;
-  if constexpr (SCALE_UE8M0) {
-    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
-  } else {
-    y_s_quant = y_s;
-  }
-
-  if (lane_id == 0) {
-    *scale_output = y_s_quant;
-  }
-
-  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
-    vec_t input_vec;
-    input_vec.cast_load(group_input + i * vec_size);
-
-#pragma unroll
-    for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = static_cast<float>(input_vec[j]);
-      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
-      group_output[i * vec_size + j] = DST_DTYPE(q_val);
-    }
-  }
-}
+__device__ __forceinline__ float silu(const float& val) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  float half = 0.5f * val;
+  float t = __tanhf(half);
+  return half * (1.0f + t);
+#else
+  return val / (1.0f + __expf(-val));
+#endif
+}
+
+__device__ float2 fmul2_rn(float2 a, float2 b) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  return __fmul2_rn(a, b);
+#else
+  float2 result;
+  result.x = a.x * b.x;
+  result.y = a.y * b.y;
+  return result;
+#endif
+}
+
+// Copied and modified from DeepEP
+__forceinline__ __device__ float fast_pow2(int x) {
+  // We can ensure `-126 <= x and x <= 127`
+  uint32_t bits_x = (x + 127) << 23;
+  return *reinterpret_cast<float*>(&bits_x);
+}
+
+// Copied and modified from DeepEP
+__forceinline__ __device__ int fast_log2_ceil(float x) {
+  auto bits_x = *reinterpret_cast<uint32_t*>(&x);
+  auto exp_x = (bits_x >> 23) & 0xff;
+  auto man_bits = bits_x & ((1 << 23) - 1);
+  return exp_x - 127 + (man_bits != 0);
+}
+
+// Copied and modified from DeepEP
+template <bool ROUND_SCALE, typename dtype_info>
+__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
+  constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
+  if constexpr (ROUND_SCALE) {
+    auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
+    scale = fast_pow2(-exp_scale_inv);
+    scale_inv = fast_pow2(exp_scale_inv);
+  } else {
+    scale_inv = amax * MAX_8BIT_INV;
+    scale = dtype_info::MAX / amax;
+  }
+}
+
+// Copied and modified from DeepEP
+template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
+__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
+  if constexpr (SCALE_UE8M0) {
+    return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
+  } else {
+    return value;
+  }
+}
+
+__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
+  asm volatile("st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr),
+               "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
+}
+
+__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
+  int4 ret;
+  asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
+               : "l"(ptr));
+  return ret;
+}
+
+template <typename T>
+struct DtypeInfo;
+
+template <>
+struct DtypeInfo<int8_t> {
+  static constexpr float MIN = -128;
+  static constexpr float MAX = 127;
+};
+
+template <>
+struct DtypeInfo<c10::Float8_e4m3fn> {
+  static constexpr float MIN = -448;
+  static constexpr float MAX = 448;
+};
+
+template <bool FUSE_SILU_AND_MUL>
+__device__ __forceinline__ int compute_input_group_start_offset(
+    int expert_idx,
+    int token_idx,
+    int hidden_dim_group_idx,
+    int hidden_size,
+    int num_tokens_per_expert,
+    int group_size) {
+  return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
+      token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
+}
+
+constexpr float LOCAL_ABSMAX_ABS = 1e-10;
+constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;
+
+struct NaiveScheduler {
+  static void compute_exec_config(
+      int threads_per_subwarp,
+      int num_local_experts,
+      int hidden_dim_num_groups,
+      int num_groups,
+      int& subwarps_per_block,
+      dim3& grid,
+      dim3& block) {
+    subwarps_per_block = ([=]() -> int {
+      if (num_groups % 16 == 0) {
+        return 16;
+      } else if (num_groups % 8 == 0) {
+        return 8;
+      } else if (num_groups % 4 == 0) {
+        return 4;
+      } else if (num_groups % 2 == 0) {
+        return 2;
+      }
+      return 1;
+    })();
+    grid = dim3(num_groups / subwarps_per_block);
+    block = dim3(subwarps_per_block * threads_per_subwarp);
+  }
+
+  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
+  __device__ __forceinline__ static void execute(
+      const int subwarps_per_block,
+      const int hidden_dim_num_groups,
+      const int32_t* masked_m,
+      const int num_tokens_per_expert,
+      FUNC fn) {
+    constexpr int expert_idx = 0;
+    const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
+    const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
+    const int64_t block_group_id = blockIdx.x * subwarps_per_block;
+    const int64_t group_id = block_group_id + subwarp_id;
+
+    int64_t input_group_start_offset;
+    if constexpr (!FUSE_SILU_AND_MUL) {
+      input_group_start_offset = group_id * GROUP_SIZE;
+    }
+
+    const int token_idx = group_id / hidden_dim_num_groups;
+    // At the hidden_size dimension, we are handling idx-th group
+    const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;
+
+    if constexpr (FUSE_SILU_AND_MUL) {
+      const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
+      input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
+          expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
+    }
+
+    fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
+  }
+};
+
+struct MaskedLayoutScheduler {
+  // TODO can be dynamically determined (which may be good when num rank is small)
+  static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
+  static constexpr int SUBWARPS_PER_BLOCK = 16;
+
+  static void compute_exec_config(
+      int threads_per_subwarp,
+      int num_local_experts,
+      int hidden_dim_num_groups,
+      int num_groups,
+      int& subwarps_per_block,
+      dim3& grid,
+      dim3& block) {
+    subwarps_per_block = SUBWARPS_PER_BLOCK;
+    TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
+    grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
+    block = dim3(subwarps_per_block * threads_per_subwarp);
+  }
+
+  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
+  __device__ __forceinline__ static void execute(
+      const int subwarps_per_block,
+      const int hidden_dim_num_groups,
+      const int32_t* masked_m,
+      const int num_tokens_per_expert,
+      FUNC fn) {
+    const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
+    const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
+    const int expert_idx = blockIdx.z;
+    const int token_idx_start = blockIdx.y;
+    const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id;
+
+    const int curr_expert_token_num = masked_m[expert_idx];
+    for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
+         token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
+      const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
+      const int64_t input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
+          expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
+      fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
+    }
+  }
+};
+
+template <
+    typename SCHEDULER,
+    int GROUP_SIZE,
+    int THREADS_PER_SUBWARP,
+    typename T,
+    typename DST_DTYPE,
+    bool IS_COLUMN_MAJOR = false,
+    bool SCALE_UE8M0 = false,
+    bool FUSE_SILU_AND_MUL = false,
+    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
+__global__ void per_token_group_quant_8bit_kernel(
+    const T* __restrict__ input,
+    DST_DTYPE* __restrict__ output_q,
+    scale_packed_t* __restrict__ output_s,
+    const int32_t* __restrict__ masked_m,
+    const int subwarps_per_block,
+    const int hidden_dim_num_groups,
+    // TODO can this be removed?
+    const int scale_expert_stride,
+    const int scale_hidden_stride,
+    const int num_tokens_per_expert) {
+  using dst_dtype_info = DtypeInfo<DST_DTYPE>;
+  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
+  SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>(
+      subwarps_per_block,
+      hidden_dim_num_groups,
+      masked_m,
+      num_tokens_per_expert,
+      [&](const int expert_idx,
+          const int token_idx,
+          const int hidden_dim_group_idx,
+          const int lane_id,
+          const int input_group_start_offset) {
+        constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
+        constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
+
+        const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
+            token_idx * hidden_dim_num_groups + hidden_dim_group_idx;
+
+        int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
+        T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
+        static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
+
+        int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE];
+        T* input_secondary_vec = reinterpret_cast<T*>(input_secondary_int4);
+        static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4));
+
+#pragma unroll
+        for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
+          input_primary_int4[j] = ld_global_nc(
+              reinterpret_cast<const int4*>(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
+        }
+
+        if constexpr (FUSE_SILU_AND_MUL) {
+          const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE;
+#pragma unroll
+          for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
+            input_secondary_int4[j] = ld_global_nc(
+                reinterpret_cast<const int4*>(
+                    input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) +
+                j);
+          }
+        }
+
+        constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+        scale_element_t* scale_output;
+        if constexpr (IS_COLUMN_MAJOR) {
+          constexpr int scale_token_stride = 1;
+          const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack;
+          const int pack_idx = hidden_dim_group_idx % num_elems_per_pack;
+          scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+              (expert_idx * scale_expert_stride * num_elems_per_pack +
+               hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
+               token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
+        } else {
+          static_assert(!SCALE_UE8M0);
+          scale_output = output_s + offset_num_groups;
+        }
+
+        // can speed up if too slow
+        if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) {
+          const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack;
+          if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and
+              (lane_id < num_elems_per_pack - remainder_num_groups)) {
+            const int shift = 1 + lane_id;
+            *(scale_output + shift) = 0;
+          }
+        }
+
+        float local_absmax = LOCAL_ABSMAX_ABS;
+#pragma unroll
+        for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
+          float val;
+          if constexpr (FUSE_SILU_AND_MUL) {
+            // TODO maybe vectorize
+            T val_lowprec = static_cast<T>(silu(static_cast<float>(input_primary_vec[j]))) * input_secondary_vec[j];
+            val = static_cast<float>(val_lowprec);
+            input_primary_vec[j] = val_lowprec;
+          } else {
+            val = static_cast<float>(input_primary_vec[j]);
+          }
+          float abs_val = fabsf(val);
+          local_absmax = fmaxf(local_absmax, abs_val);
+        }
+
+        local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
+
+        float y_scale, y_scale_inv;
+        calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
+        float2 y_scale_repeated = {y_scale, y_scale};
+
+        if (lane_id == 0) {
+          *scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
+        }
+
+        int4 output_buf;
+        static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
+        if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
+          const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
+          static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
+          static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
+#pragma unroll
+          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
+            float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
+            float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated);
+            output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
+          }
+        } else {
+          const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
+#pragma unroll
+          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
+            float val = static_cast<float>(input_primary_vec[j]);
+            float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
+            output_buf_ptr[j] = DST_DTYPE(q_val);
+          }
+        }
+
+        st_global(
+            reinterpret_cast<int4*>(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE),
+            output_buf);
+      });
+}

 void sgl_per_token_group_quant_8bit(
+    // vanilla: (num_tokens, hidden_size)
+    // fuse_silu_and_mul: (num_tokens, hidden_size * 2)
+    // fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
...
@@ -121,120 +398,113 @@ void sgl_per_token_group_quant_8bit(
     double eps,
     double min_8bit,
     double max_8bit,
-    bool scale_ue8m0 = false) {
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);

-  const int num_groups = input.numel() / group_size;
-
+  TORCH_CHECK(input.numel() > 0);
+  TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13);
   CHECK_EQ(input.numel() % group_size, 0);
-  CHECK_EQ(output_s.dim(), 2);
-
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  constexpr int THREADS_PER_GROUP = 16;
-
-  int groups_per_block = 1;
-
-  if (num_groups % 16 == 0) {
-    groups_per_block = 16;
-  } else if (num_groups % 8 == 0) {
-    groups_per_block = 8;
-  } else if (num_groups % 4 == 0) {
-    groups_per_block = 4;
-  } else if (num_groups % 2 == 0) {
-    groups_per_block = 2;
-  }
+  const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1);
+
+  const bool masked_layout = masked_m.has_value();
+  TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));
+  const int num_local_experts = masked_layout ? input.size(0) : 1;
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   auto dst_type = output_q.scalar_type();
-  const int num_blocks = num_groups / groups_per_block;
-  const int num_threads = groups_per_block * THREADS_PER_GROUP;
-
-  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
-  const int hidden_dim = input.size(input.dim() - 1);
-  const int num_groups_per_row = hidden_dim / group_size;
-  const int scale_stride = output_s.stride(1);
-
-#define LAUNCH_KERNEL(T, DST_DTYPE)                                                               \
-  do {                                                                                            \
-    dim3 grid(num_blocks);                                                                        \
-    dim3 block(num_threads);                                                                      \
-    if (is_column_major) {                                                                        \
-      if (scale_ue8m0) {                                                                          \
-        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>(  \
-            static_cast<T*>(input.data_ptr()),                                                    \
-            output_q.data_ptr(),                                                                  \
-            static_cast<uint32_t*>(output_s.data_ptr()),                                          \
-            group_size,                                                                           \
-            num_groups,                                                                           \
-            groups_per_block,                                                                     \
-            (float)eps,                                                                           \
-            (float)min_8bit,                                                                      \
-            (float)max_8bit,                                                                      \
-            num_groups_per_row,                                                                   \
-            scale_stride);                                                                        \
-      } else {                                                                                    \
-        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
-            static_cast<T*>(input.data_ptr()),                                                    \
-            output_q.data_ptr(),                                                                  \
-            static_cast<float*>(output_s.data_ptr()),                                             \
-            group_size,                                                                           \
-            num_groups,                                                                           \
-            groups_per_block,                                                                     \
-            (float)eps,                                                                           \
-            (float)min_8bit,                                                                      \
-            (float)max_8bit,                                                                      \
-            num_groups_per_row,                                                                   \
-            scale_stride);                                                                        \
-      }                                                                                           \
-    } else {                                                                                      \
-      assert(!scale_ue8m0);                                                                       \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>(         \
-          static_cast<T*>(input.data_ptr()),                                                      \
-          output_q.data_ptr(),                                                                    \
-          static_cast<float*>(output_s.data_ptr()),                                               \
-          group_size,                                                                             \
-          num_groups,                                                                             \
-          groups_per_block,                                                                       \
-          (float)eps,                                                                             \
-          (float)min_8bit,                                                                        \
-          (float)max_8bit);                                                                       \
-    }                                                                                             \
-  } while (0)
-
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (dst_type == at::ScalarType::Char) {
-      LAUNCH_KERNEL(scalar_t, int8_t);
-      return true;
-    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
-      LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
-      return true;
-    }
-    return false;
-  });
-#undef LAUNCH_KERNEL
-}
-
-void sgl_per_token_group_quant_int8(
-    torch::Tensor input,
-    torch::Tensor output_q,
-    torch::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
-}
-
-void sgl_per_token_group_quant_fp8(
-    torch::Tensor input,
-    torch::Tensor output_q,
-    torch::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
-}
+  const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
+  const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size;
+  const int num_tokens_per_expert = static_cast<int>(output_q.size(-2));
+  const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
+  const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));
+
+#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \
+  do {                                                                                                     \
+    int subwarps_per_block;                                                                                \
+    dim3 grid, block;                                                                                      \
+    SCHEDULER::compute_exec_config(                                                                        \
+        THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
+                                                                                                           \
+    per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
+        <<<grid, block, 0, stream>>>(                                                                      \
+            static_cast<T*>(input.data_ptr()),                                                             \
+            static_cast<DST_DTYPE*>(output_q.data_ptr()),                                                  \
+            static_cast<output_s_dtype*>(output_s.data_ptr()),                                             \
+            static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0),                        \
+            subwarps_per_block,                                                                            \
+            hidden_dim_num_groups,                                                                         \
+            scale_expert_stride,                                                                           \
+            scale_hidden_stride,                                                                           \
+            num_tokens_per_expert);                                                                        \
+  } while (0)
+
+#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE)                                                            \
+  do {                                                                                                     \
+    constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16;                                                   \
+    TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T));               \
+                                                                                                           \
+    using dst_dtype_info = DtypeInfo<DST_DTYPE>;                                                           \
+    CHECK_EQ(dst_dtype_info::MIN, min_8bit);                                                               \
+    CHECK_EQ(dst_dtype_info::MAX, max_8bit);                                                               \
+                                                                                                           \
+    if (is_column_major) {                                                                                 \
+      if (scale_ue8m0) {                                                                                   \
+        if (fuse_silu_and_mul) {                                                                           \
+          if (masked_layout) {                                                                             \
+            LAUNCH_KERNEL_INNER(                                                                           \
+                MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
+          } else {                                                                                         \
+            LAUNCH_KERNEL_INNER(                                                                           \
+                NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
+          }                                                                                                \
+        } else {                                                                                           \
+          LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \
+        }                                                                                                  \
+      } else {                                                                                             \
+        LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true);   \
+      }                                                                                                    \
+    } else {                                                                                               \
+      LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false);    \
+    }                                                                                                      \
+  } while (0)
+
+#define LAUNCH_KERNEL_OUTER(...)                    \
+  switch (group_size) {                             \
+    case 16:                                        \
+      LAUNCH_KERNEL(16, __VA_ARGS__);               \
+      break;                                        \
+    case 32:                                        \
+      LAUNCH_KERNEL(32, __VA_ARGS__);               \
+      break;                                        \
+    case 64:                                        \
+      LAUNCH_KERNEL(64, __VA_ARGS__);               \
+      break;                                        \
+    case 128:                                       \
+      LAUNCH_KERNEL(128, __VA_ARGS__);              \
+      break;                                        \
+    default:                                        \
+      TORCH_CHECK(false, "Unsupported group_size"); \
+  }                                                 \
+  while (0)
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL_OUTER(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn);
+      return true;
+    }
+    return false;
+  });
+
+#undef LAUNCH_KERNEL_INNER
+}
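The ROUND_SCALE branch of calculate_fp8_scales above rounds the inverse scale up to a power of two using exponent-bit tricks (fast_log2_ceil / fast_pow2). A plain-Python restatement of that math, assuming the FP8 E4M3 maximum of 448 from DtypeInfo:

import math

def ue8m0_scales(amax: float, max_8bit: float = 448.0):
    # Mirrors calculate_fp8_scales with ROUND_SCALE=true: scale_inv is rounded up
    # to a power of two; extract_required_scale_format then keeps only its biased
    # exponent byte when SCALE_UE8M0 is set.
    exp_scale_inv = math.ceil(math.log2(amax / max_8bit))
    return 2.0 ** (-exp_scale_inv), 2.0 ** exp_scale_inv  # (scale, scale_inv)

scale, scale_inv = ue8m0_scales(30.0)
assert (scale, scale_inv) == (8.0, 0.125)  # 30 / 448 ~= 0.067 rounds up to 2**-3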
sgl-kernel/include/sgl_kernel_ops.h

...
@@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);
 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0);
-void sgl_per_token_group_quant_int8(
-    at::Tensor input,
-    at::Tensor output_q,
-    at::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max);
+    double min_8bit,
+    double max_8bit,
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void bmm_fp8(
...
sgl-kernel/python/sgl_kernel/__init__.py

...
@@ -55,8 +55,7 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_fp8,
-    sgl_per_token_group_quant_int8,
+    sgl_per_token_group_quant_8bit,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
...
sgl-kernel/python/sgl_kernel/gemm.py

...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
...
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-    )
-
-
-def sgl_per_token_group_quant_int8(
-    input: torch.Tensor,
-    output_q: torch.Tensor,
-    output_s: torch.Tensor,
-    group_size: int,
-    eps: float,
-    int8_min: float,
-    int8_max: float,
-) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
-        input, output_q, output_s, group_size, eps, int8_min, int8_max
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+        fuse_silu_and_mul,
+        masked_m,
+    )
...
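A hedged usage sketch of the renamed wrapper (not part of the diff): the caller allocates the quantized output and the per-group scales and invokes the op in place. The shapes assume the plain row-major layout without fuse_silu_and_mul or masked_m, and the +-448 bounds and eps=1e-10 mirror the checks added in the CUDA file.

import torch
from sgl_kernel import sgl_per_token_group_quant_8bit

num_tokens, hidden_dim, group_size = 16, 7168, 128
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)

# Quantize in place: x_q receives the fp8 values, x_s the per-group scales.
sgl_per_token_group_quant_8bit(
    x, x_q, x_s, group_size,
    eps=1e-10, fp8_min=-448.0, fp8_max=448.0,
    scale_ue8m0=False, fuse_silu_and_mul=False, masked_m=None,
)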
sgl-kernel/python/sgl_kernel/test_utils.py (new file, 0 → 100644)

import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(num_tokens, num_local_experts, gen_cpu=gen_cpu)
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor([num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int)
        else:
            raise NotImplementedError
        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(total: int, arr_len: int, gen_cpu, dtype=torch.int) -> list[int]:
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (
        a.dtype == b.dtype
    ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()

    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
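A short usage sketch for the helpers above (assuming a CUDA device and the flag names used by the parametrized configs): create_per_token_group_quant_test_data returns either a (num_tokens, hidden_dim) tensor plus None, or, in masked-layout mode, a (num_local_experts, 768 * num_ranks, hidden_dim * 2) tensor plus the per-expert token counts.

import torch
from sgl_kernel.test_utils import (
    assert_all_close_or_tiny_diff,
    create_per_token_group_quant_test_data,
)

flags = dict(fuse_silu_and_mul=True, masked_layout_mode="balanced")
x, masked_m = create_per_token_group_quant_test_data(
    num_tokens=768 * 8, hidden_dim=2048, num_ranks=48, flags=flags
)
# 288 global experts split over 48 ranks -> 6 local experts; counts sum to num_tokens.
assert x.shape == (288 // 48, 768 * 48, 2048 * 2) and masked_m.sum().item() == 768 * 8

# Identical fp8 tensors trivially satisfy the "tiny diff" tolerance.
a = torch.randn(4, 128, device="cuda").to(torch.float8_e4m3fn)
assert_all_close_or_tiny_diff(a, a.clone())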
sgl-kernel/tests/test_per_token_group_quant_8bit.py

 import itertools
+import os
+import time
+from pathlib import Path

 import pytest
 import torch
+from sgl_kernel.test_utils import (
+    assert_all_close_or_tiny_diff,
+    create_per_token_group_quant_test_data,
+)
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
 )
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
-from sglang.srt.layers.quantization.utils import assert_fp8_all_close
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

+configs = list(
+    itertools.product(
+        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
+        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
+        [16, 32, 64, 128],  # group_size
+        [None],  # num_ranks
+        [fp8_type_, torch.int8],  # dtype
+        [
+            dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False,
+                 fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False,
+                 fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False,
+                 fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                 fuse_silu_and_mul=False, masked_layout_mode=None),
+        ],
+    )
+) + list(
+    itertools.product(
+        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+        # TODO support more
+        [2048],
+        [128],
+        [8, 16, 32, 48],
+        [fp8_type_],
+        [
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                 fuse_silu_and_mul=True, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                 fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                 fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+                 fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+        ],
+    )
+)


 @pytest.mark.parametrize(
-    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
-    list(
-        itertools.product(
-            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
-            [256, 512, 1024, 2048, 4096],  # hidden_dim
-            [8, 16, 32, 64, 128],  # group_size
-            # TODO test int8
-            [fp8_type_],  # dtype
-            [
-                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-            ],
-        )
-    ),
+    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags",
+    configs,
 )
 def test_per_token_group_quant_with_column_major(
     num_tokens,
     hidden_dim,
     group_size,
+    num_ranks,
     dst_dtype,
     flags,
 ):
-    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
-        pytest.skip()
-
-    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
-    if flags["scale_ue8m0"] and (arch_major <= 9):
-        pytest.skip("Only Blackwell need ue8m0 fusion")
-        return
-
-    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
+    print(f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}")
+
+    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
+        pytest.skip("scale_ue8m0 only supported on Blackwell")
+
+    if (flags["scale_ue8m0"] and (group_size != 128)) or (
+        (dst_dtype == torch.int8) and flags["column_major_scales"]
+    ):
+        pytest.skip()
+        return
+
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )
+    # print("hack data!!!")
+    # x = torch.full_like(x, fill_value=100)

     execute_kwargs = dict(
         x=x,
+        masked_m=masked_m,
         group_size=group_size,
         eps=1e-10,
         dst_dtype=dst_dtype,
-        **flags,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
-
-    # torch.set_printoptions(profile="full")
-    # print(f"{x_q_triton=}")
-    # print(f"{x_s_triton=}")
-    # print(f"{x_q_sglang=}")
-    # print(f"{x_s_sglang=}")
-    # torch.set_printoptions(profile="default")
-
-    assert_fp8_all_close(x_q_triton, x_q_sglang)
-    torch.testing.assert_close(
-        x_s_triton.contiguous(),
-        x_s_sglang.contiguous(),
-        rtol=1e-3,
-        atol=1e-5,
-        msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
-    )
+    def _postprocess(x_q, x_s):
+        if masked_m is not None:
+            print(f"Mask tokens after {masked_m} to be zero")
+            for i in range(len(masked_m)):
+                x_q[i, masked_m[i]:, :] = 0
+                x_s[i, masked_m[i]:, :] = 0
+        return x_q, x_s
+
+    x_q_triton, x_s_triton = _postprocess(*triton_per_token_group_quant_8bit(**execute_kwargs))
+    x_q_sglang, x_s_sglang = _postprocess(*sglang_per_token_group_quant_8bit(**execute_kwargs))
+
+    try:
+        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
+        torch.testing.assert_close(
+            x_s_triton.contiguous(),
+            x_s_sglang.contiguous(),
+            rtol=1e-3,
+            atol=1e-5,
+            msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
+        )
+    except AssertionError:
+        # torch.set_printoptions(profile="full")
+        print(f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}")
+        print(f"{x=}")
+        print(f"{masked_m=}")
+        print(f"{x_q_triton=}")
+        print(f"{x_s_triton=}")
+        print(f"{x_q_sglang=}")
+        print(f"{x_s_sglang=}")
+        # torch.set_printoptions(profile="default")
+        # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
+        #     import matplotlib.pyplot as plt
+        #
+        #     base_stem = time.time()
+        #     for name, value in [
+        #         ("x_q", x_q_triton != x_q_sglang),
+        #         ("x_s", x_s_triton != x_s_sglang),
+        #     ]:
+        #         value = value.reshape((-1, value.shape[-1]))
+        #         plt.figure(figsize=(20, 20))
+        #         plt.imshow((value * 1.0).cpu().numpy())
+        #         p = Path(d) / f"{base_stem}_{name}.png"
+        #         print(f"Write diff to {p}", flush=True)
+        #         plt.savefig(p)
+        raise


 if __name__ == "__main__":
     pytest.main([__file__])