sglang / Commits / 339f8eef

Unverified commit 339f8eef, authored Sep 05, 2025 by fzyzcjy, committed by GitHub on Sep 05, 2025. Parent: afd9f2f5.

[1/2] Optimizations and refactors about quant kernel (#9534)

Showing 11 changed files with 1002 additions and 335 deletions (+1002, -335):
  python/sglang/srt/bench_utils.py                           +4    -2
  python/sglang/srt/layers/quantization/fp8_kernel.py        +29   -8
  python/sglang/srt/layers/quantization/int8_kernel.py       +8    -2
  sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py   +203  -48
  sgl-kernel/csrc/common_extension.cc                        +3    -8
  sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu         +447  -177
  sgl-kernel/include/sgl_kernel_ops.h                        +6    -12
  sgl-kernel/python/sgl_kernel/__init__.py                   +1    -2
  sgl-kernel/python/sgl_kernel/gemm.py                       +15   -18
  sgl-kernel/python/sgl_kernel/test_utils.py                 +125  -0
  sgl-kernel/tests/test_per_token_group_quant_8bit.py        +161  -58
python/sglang/srt/bench_utils.py (+4, -2)

--- a/python/sglang/srt/bench_utils.py
+++ b/python/sglang/srt/bench_utils.py
 import os
+import re
 import sys
 from contextlib import nullcontext
@@ -108,7 +109,8 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([name in line for line in prof_lines]) == 1
+                sum([int(re.search(name, line) is not None) for line in prof_lines])
+                == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
     # Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
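The switch from substring matching to re.search matters because the benchmark below passes an alternation pattern ("_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel") as a single kernel name; a plain `name in line` test would look for the literal pipe character and never match. A minimal sketch of the difference (the profiler rows are made up for illustration):

import re

prof_lines = [
    "  _per_token_group_quant_8bit_kernel   12.3us   30",  # hypothetical profiler row
    "  some_other_kernel                     1.0us   30",
]
pattern = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"

# Substring matching treats the pattern literally, so nothing matches.
print(sum(pattern in line for line in prof_lines))  # 0
# re.search interprets the "|" alternation, so the matching row is found.
print(sum(int(re.search(pattern, line) is not None) for line in prof_lines))  # 1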
python/sglang/srt/layers/quantization/fp8_kernel.py (+29, -8)

@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
@@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8(
     )
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+            )
+        else:
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )

     return x_q, x_s
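Per-token-group quantization splits every token's hidden vector into group_size-wide chunks and gives each chunk its own scale derived from the chunk absmax. A rough pure-PyTorch sketch of the vanilla row-major case that these kernels accelerate; the ±448 FP8-E4M3 range and the 1e-10 eps come from this commit, while the helper name and the rest are illustrative:

import torch

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # x: (num_tokens, hidden_size), hidden_size divisible by group_size
    num_tokens, hidden_size = x.shape
    fp8_max = 448.0  # DtypeInfo<c10::Float8_e4m3fn>::MAX in the CUDA kernel below
    g = x.float().reshape(num_tokens, hidden_size // group_size, group_size)
    absmax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    x_s = absmax / fp8_max  # one scale per (token, group)
    x_q = (g / x_s).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.reshape(num_tokens, hidden_size), x_s.squeeze(-1)

The column_major_scales / scale_tma_aligned / scale_ue8m0 flags that appear throughout this commit only change how x_s is laid out or rounded, not this basic computation.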
python/sglang/srt/layers/quantization/int8_kernel.py (+8, -2)

@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)

     return x_q, x_s
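The aliasing above works because the INT8 path is the same group-wise scheme with a different clamp range, -128/127, as the DtypeInfo<int8_t> specialization added further down in this commit shows. A rough reference of the INT8 variant, mirroring the FP8 sketch above; the rounding mode is illustrative and may differ from the actual kernels:

import torch

def per_token_group_quant_int8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Same grouping as the FP8 sketch above; only the clamp range changes.
    num_tokens, hidden_size = x.shape
    int8_max = 127.0
    g = x.float().reshape(num_tokens, hidden_size // group_size, group_size)
    x_s = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / int8_max
    # Rounding here is illustrative; the CUDA/Triton kernels may handle ties/truncation differently.
    x_q = (g / x_s).round().clamp(-128, 127).to(torch.int8)
    return x_q.reshape(num_tokens, hidden_size), x_s.squeeze(-1)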
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py (+203, -48)

The benchmark grid gains num_ranks, fuse_silu_and_mul, and masked_layout_mode axes, input data now comes from create_per_token_group_quant_test_data instead of a plain torch.randn, bench_kineto is given a kernel-name alternation pattern and a configurable num_tests, and the Triton line is relabeled "Triton (Inaccurate)" because it launches multiple kernels and only the core one is timed. The previous flat num_tokens/hidden_dim/group_size/dst_dtype/flags product is replaced by three modes selected by environment variables. After the change the file reads:

import itertools
import os
import time
from functools import partial
from pathlib import Path

import torch
import triton
from sgl_kernel.test_utils import create_per_token_group_quant_test_data
from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
    ...  # unchanged (collapsed in the diff view)

@@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"

if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
    # configs = [[
    #     768,
    #     16384,
    #     128,
    #     None,
    #     fp8_type_,
    #     dict(
    #         column_major_scales=True,
    #         scale_tma_aligned=True,
    #         scale_ue8m0=True,
    #         fuse_silu_and_mul=False,
    #         masked_layout_mode=None,
    #     ),
    # ]]
    configs = [
        [
            768 * 8,
            2048,
            128,
            48,
            fp8_type_,
            dict(
                column_major_scales=True,
                scale_tma_aligned=True,
                scale_ue8m0=True,
                fuse_silu_and_mul=True,
                # masked_layout_mode=None,
                masked_layout_mode="balanced",
                # masked_layout_mode="extreme",
            ),
        ],
    ]
elif mode_concentrated:
    configs = list(
        itertools.product(
            [768],
            [1536, 7168, 16384],
            [128],
            [None],
            [fp8_type_],
            [
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
            ],
        )
    ) + list(
        itertools.product(
            [768 * 8],
            [2048],
            [128],
            [48],
            [fp8_type_],
            [
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
            ],
        )
    )
else:
    configs = list(
        itertools.product(
            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
            [1536, 7168, 16384],
            [128],
            [None],
            [fp8_type_],
            [
                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
            ],
        )
    ) + list(
        itertools.product(
            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
            [2048],
            [128],
            [8, 16, 32, 48],
            [fp8_type_],
            [
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
            ],
        )
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=[
            "num_tokens",
            "hidden_dim",
            "group_size",
            "num_ranks",
            "dst_dtype",
            "flags",
        ],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        # Triton has multi kernels and we only report the time for the core one
        line_names=["Triton (Inaccurate)", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(
    num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
):
    print(
        f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
    )

    x, masked_m = create_per_token_group_quant_test_data(
        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
    )

    fn, kernel_names = {
        "triton": (
            triton_per_token_group_quant_8bit,
            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
        ),
        "sglang": (
            sglang_per_token_group_quant_8bit,
            "per_token_group_quant_8bit_kernel",
        ),
    }[provider]

    bench_fn = lambda: fn(
        x=x,
        masked_m=masked_m,
        group_size=group_size,
        dst_dtype=dst_dtype,
        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
    )

    time_s = bench_kineto(
        bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
    )
    return time_s * 1e6
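Several of the benchmarked configurations enable fuse_silu_and_mul, where the kernel reads a gate/up pair of width 2 * hidden_size and quantizes silu(gate) * up in one pass (the input-shape comments in the CUDA file further down spell this out). A hedged PyTorch sketch of that fused computation for the unmasked case; the function name is made up, and precision details differ from the real kernel, which keeps the product in the input dtype before taking the absmax:

import torch

def silu_and_mul_then_group_quant_reference(x: torch.Tensor, group_size: int):
    # x: (num_tokens, 2 * hidden_size); first half is gated through SiLU, second half multiplies it
    hidden_size = x.shape[-1] // 2
    gate, up = x[..., :hidden_size], x[..., hidden_size:]
    y = torch.nn.functional.silu(gate.float()) * up.float()  # (num_tokens, hidden_size)
    g = y.reshape(y.shape[0], hidden_size // group_size, group_size)
    x_s = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 448.0
    x_q = (g / x_s).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return x_q.reshape(y.shape), x_s.squeeze(-1)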
sgl-kernel/csrc/common_extension.cc (+3, -8)

@@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
+  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
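With this schema registered, the op is reachable through the PyTorch dispatcher exactly as the thin wrapper in sgl-kernel/python/sgl_kernel/gemm.py (later in this commit) calls it. A sketch of a direct call for the plain FP8 case; the output shapes and dtypes here are assumptions chosen for illustration, and eps is kept at 1e-10 because the new CUDA entry point checks it against its internal constant:

import torch

num_tokens, hidden_size, group_size = 16, 1024, 128
x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty(num_tokens, hidden_size, device="cuda", dtype=torch.float8_e4m3fn)
# Assumed row-major scale layout: one float32 scale per (token, group).
x_s = torch.empty(num_tokens, hidden_size // group_size, device="cuda", dtype=torch.float32)

torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
    x, x_q, x_s, group_size,
    1e-10,          # eps (must match the kernel's LOCAL_ABSMAX_ABS)
    -448.0, 448.0,  # fp8_min, fp8_max for float8_e4m3fn
    False,          # scale_ue8m0
    False,          # fuse_silu_and_mul
    None,           # masked_m
)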
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu (+447, -177)

The file is largely rewritten. The previous version had a single kernel templated only on the input/output dtypes, column-major scales, and UE8M0, with a fixed 16 threads per group and a groups_per_block launch heuristic, plus thin sgl_per_token_group_quant_int8 / sgl_per_token_group_quant_fp8 wrappers at the end of the file; all of that is removed. The new version re-templates the kernel on a scheduler (NaiveScheduler or MaskedLayoutScheduler), the group size, and the number of threads per subwarp, adds fused SiLU-and-mul and a masked per-expert layout, uses 32-byte vectorized loads/stores through inline PTX, and adopts DeepEP-derived power-of-two (UE8M0) scale helpers. The new version:

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

template <int THREADS_PER_SUBWARP>
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = 0xffff;

  static_assert(
      (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
      "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
  if constexpr (THREADS_PER_SUBWARP >= 16) {
    val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  }
  if constexpr (THREADS_PER_SUBWARP >= 8) {
    val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  }
  if constexpr (THREADS_PER_SUBWARP >= 4) {
    val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  }
  if constexpr (THREADS_PER_SUBWARP >= 2) {
    val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  }
  return val;
}

__device__ __forceinline__ float silu(const float& val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  float half = 0.5f * val;
  float t = __tanhf(half);
  return half * (1.0f + t);
#else
  return val / (1.0f + __expf(-val));
#endif
}

__device__ float2 fmul2_rn(float2 a, float2 b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  return __fmul2_rn(a, b);
#else
  float2 result;
  result.x = a.x * b.x;
  result.y = a.y * b.y;
  return result;
#endif
}

// Copied and modified from DeepEP
__forceinline__ __device__ float fast_pow2(int x) {
  // We can ensure `-126 <= x and x <= 127`
  uint32_t bits_x = (x + 127) << 23;
  return *reinterpret_cast<float*>(&bits_x);
}

// Copied and modified from DeepEP
__forceinline__ __device__ int fast_log2_ceil(float x) {
  auto bits_x = *reinterpret_cast<uint32_t*>(&x);
  auto exp_x = (bits_x >> 23) & 0xff;
  auto man_bits = bits_x & ((1 << 23) - 1);
  return exp_x - 127 + (man_bits != 0);
}

// Copied and modified from DeepEP
template <bool ROUND_SCALE, typename dtype_info>
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
  constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
  if constexpr (ROUND_SCALE) {
    auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
    scale = fast_pow2(-exp_scale_inv);
    scale_inv = fast_pow2(exp_scale_inv);
  } else {
    scale_inv = amax * MAX_8BIT_INV;
    scale = dtype_info::MAX / amax;
  }
}

// Copied and modified from DeepEP
template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
  if constexpr (SCALE_UE8M0) {
    return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
  } else {
    return value;
  }
}

__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
  asm volatile("st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr),
               "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}

__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
  int4 ret;
  asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
               : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
               : "l"(ptr));
  return ret;
}

template <typename T>
struct DtypeInfo;

template <>
struct DtypeInfo<int8_t> {
  static constexpr float MIN = -128;
  static constexpr float MAX = 127;
};

template <>
struct DtypeInfo<c10::Float8_e4m3fn> {
  static constexpr float MIN = -448;
  static constexpr float MAX = 448;
};

template <bool FUSE_SILU_AND_MUL>
__device__ __forceinline__ int compute_input_group_start_offset(
    int expert_idx,
    int token_idx,
    int hidden_dim_group_idx,
    int hidden_size,
    int num_tokens_per_expert,
    int group_size) {
  return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
         token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
}

constexpr float LOCAL_ABSMAX_ABS = 1e-10;
constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;

struct NaiveScheduler {
  static void compute_exec_config(
      int threads_per_subwarp,
      int num_local_experts,
      int hidden_dim_num_groups,
      int num_groups,
      int& subwarps_per_block,
      dim3& grid,
      dim3& block) {
    subwarps_per_block = ([=]() -> int {
      if (num_groups % 16 == 0) {
        return 16;
      } else if (num_groups % 8 == 0) {
        return 8;
      } else if (num_groups % 4 == 0) {
        return 4;
      } else if (num_groups % 2 == 0) {
        return 2;
      }
      return 1;
    })();
    grid = dim3(num_groups / subwarps_per_block);
    block = dim3(subwarps_per_block * threads_per_subwarp);
  }

  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
  __device__ __forceinline__ static void execute(
      const int subwarps_per_block,
      const int hidden_dim_num_groups,
      const int32_t* masked_m,
      const int num_tokens_per_expert,
      FUNC fn) {
    constexpr int expert_idx = 0;
    const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
    const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
    const int64_t block_group_id = blockIdx.x * subwarps_per_block;
    const int64_t group_id = block_group_id + subwarp_id;

    int64_t input_group_start_offset;
    if constexpr (!FUSE_SILU_AND_MUL) {
      input_group_start_offset = group_id * GROUP_SIZE;
    }

    const int token_idx = group_id / hidden_dim_num_groups;
    // At the hidden_size dimension, we are handling idx-th group
    const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;

    if constexpr (FUSE_SILU_AND_MUL) {
      const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
      input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
          expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
    }

    fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
  }
};

struct MaskedLayoutScheduler {
  // TODO can be dynamically determined (which may be good when num rank is small)
  static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
  static constexpr int SUBWARPS_PER_BLOCK = 16;

  static void compute_exec_config(
      int threads_per_subwarp,
      int num_local_experts,
      int hidden_dim_num_groups,
      int num_groups,
      int& subwarps_per_block,
      dim3& grid,
      dim3& block) {
    subwarps_per_block = SUBWARPS_PER_BLOCK;
    TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
    grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
    block = dim3(subwarps_per_block * threads_per_subwarp);
  }

  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
  __device__ __forceinline__ static void execute(
      const int subwarps_per_block,
      const int hidden_dim_num_groups,
      const int32_t* masked_m,
      const int num_tokens_per_expert,
      FUNC fn) {
    const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
    const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
    const int expert_idx = blockIdx.z;
    const int token_idx_start = blockIdx.y;
    const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id;

    const int curr_expert_token_num = masked_m[expert_idx];

    for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
         token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
      const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
      const int64_t input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
          expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
      fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
    }
  }
};

template <
    typename SCHEDULER,
    int GROUP_SIZE,
    int THREADS_PER_SUBWARP,
    typename T,
    typename DST_DTYPE,
    bool IS_COLUMN_MAJOR = false,
    bool SCALE_UE8M0 = false,
    bool FUSE_SILU_AND_MUL = false,
    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s,
    const int32_t* __restrict__ masked_m,
    const int subwarps_per_block,
    const int hidden_dim_num_groups,
    // TODO can this be removed?
    const int scale_expert_stride,
    const int scale_hidden_stride,
    const int num_tokens_per_expert) {
  using dst_dtype_info = DtypeInfo<DST_DTYPE>;
  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>(
      subwarps_per_block,
      hidden_dim_num_groups,
      masked_m,
      num_tokens_per_expert,
      [&](const int expert_idx,
          const int token_idx,
          const int hidden_dim_group_idx,
          const int lane_id,
          const int input_group_start_offset) {
        constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
        constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);

        const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
                                      token_idx * hidden_dim_num_groups + hidden_dim_group_idx;

        int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
        T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
        static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));

        int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE];
        T* input_secondary_vec = reinterpret_cast<T*>(input_secondary_int4);
        static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4));

#pragma unroll
        for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
          input_primary_int4[j] = ld_global_nc(
              reinterpret_cast<const int4*>(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
        }

        if constexpr (FUSE_SILU_AND_MUL) {
          const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE;
#pragma unroll
          for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
            input_secondary_int4[j] = ld_global_nc(
                reinterpret_cast<const int4*>(
                    input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) +
                j);
          }
        }

        constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
        scale_element_t* scale_output;
        if constexpr (IS_COLUMN_MAJOR) {
          constexpr int scale_token_stride = 1;
          const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack;
          const int pack_idx = hidden_dim_group_idx % num_elems_per_pack;
          scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                         (expert_idx * scale_expert_stride * num_elems_per_pack +
                          hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
                          token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
        } else {
          static_assert(!SCALE_UE8M0);
          scale_output = output_s + offset_num_groups;
        }

        // can speed up if too slow
        if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) {
          const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack;
          if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and
              (lane_id < num_elems_per_pack - remainder_num_groups)) {
            const int shift = 1 + lane_id;
            *(scale_output + shift) = 0;
          }
        }

        float local_absmax = LOCAL_ABSMAX_ABS;

#pragma unroll
        for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
          float val;
          if constexpr (FUSE_SILU_AND_MUL) {
            // TODO maybe vectorize
            T val_lowprec = static_cast<T>(silu(static_cast<float>(input_primary_vec[j]))) * input_secondary_vec[j];
            val = static_cast<float>(val_lowprec);
            input_primary_vec[j] = val_lowprec;
          } else {
            val = static_cast<float>(input_primary_vec[j]);
          }
          float abs_val = fabsf(val);
          local_absmax = fmaxf(local_absmax, abs_val);
        }

        local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);

        float y_scale, y_scale_inv;
        calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
        float2 y_scale_repeated = {y_scale, y_scale};

        if (lane_id == 0) {
          *scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
        }

        int4 output_buf;
        static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
        if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
          const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
          static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
          static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
#pragma unroll
          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
            float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
            float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated);
            output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
          }
        } else {
          const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
#pragma unroll
          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
            float val = static_cast<float>(input_primary_vec[j]);
            float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
            output_buf_ptr[j] = DST_DTYPE(q_val);
          }
        }

        st_global(
            reinterpret_cast<int4*>(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE),
            output_buf);
      });
}

void sgl_per_token_group_quant_8bit(
    // vanilla: (num_tokens, hidden_size)
    // fuse_silu_and_mul: (num_tokens, hidden_size * 2)
    // fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double min_8bit,
    double max_8bit,
    bool scale_ue8m0,
    bool fuse_silu_and_mul,
    const std::optional<torch::Tensor>& masked_m) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);

  TORCH_CHECK(input.numel() > 0);
  TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13);
  CHECK_EQ(input.numel() % group_size, 0);

  const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1);
  const bool masked_layout = masked_m.has_value();
  TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));
  const int num_local_experts = masked_layout ? input.size(0) : 1;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto dst_type = output_q.scalar_type();

  const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size;
  const int num_tokens_per_expert = static_cast<int>(output_q.size(-2));

  const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
  const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
  const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));

#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \
  do { \
    int subwarps_per_block; \
    dim3 grid, block; \
    SCHEDULER::compute_exec_config( \
        THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
    \
    per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
        <<<grid, block, 0, stream>>>( \
            static_cast<T*>(input.data_ptr()), \
            static_cast<DST_DTYPE*>(output_q.data_ptr()), \
            static_cast<output_s_dtype*>(output_s.data_ptr()), \
            static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0), \
            subwarps_per_block, \
            hidden_dim_num_groups, \
            scale_expert_stride, \
            scale_hidden_stride, \
            num_tokens_per_expert); \
  } while (0)

#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \
  do { \
    constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \
    TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \
    \
    using dst_dtype_info = DtypeInfo<DST_DTYPE>; \
    CHECK_EQ(dst_dtype_info::MIN, min_8bit); \
    CHECK_EQ(dst_dtype_info::MAX, max_8bit); \
    \
    if (is_column_major) { \
      if (scale_ue8m0) { \
        if (fuse_silu_and_mul) { \
          if (masked_layout) { \
            LAUNCH_KERNEL_INNER( \
                MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
          } else { \
            LAUNCH_KERNEL_INNER( \
                NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
          } \
        } else { \
          LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \
        } \
      } else { \
        LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \
      } \
    } else { \
      LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \
    } \
  } while (0)

#define LAUNCH_KERNEL_OUTER(...) \
  switch (group_size) { \
    case 16: \
      LAUNCH_KERNEL(16, __VA_ARGS__); \
      break; \
    case 32: \
      LAUNCH_KERNEL(32, __VA_ARGS__); \
      break; \
    case 64: \
      LAUNCH_KERNEL(64, __VA_ARGS__); \
      break; \
    case 128: \
      LAUNCH_KERNEL(128, __VA_ARGS__); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported group_size"); \
  } \
  while (0)

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
    if (dst_type == at::ScalarType::Char) {
      LAUNCH_KERNEL_OUTER(scalar_t, int8_t);
      return true;
    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
      LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn);
      return true;
    }
    return false;
  });

#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_INNER
}
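The DeepEP-derived helpers above round the scale to a power of two by working directly on the IEEE-754 bit pattern: fast_log2_ceil reads the exponent field (adding 1 when any mantissa bit is set), fast_pow2 writes a biased exponent back, and extract_required_scale_format keeps only that exponent byte for the UE8M0 layout. A small Python check of the same bit manipulations, purely for illustration and not part of the commit:

import math
import struct

def float_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def fast_log2_ceil(x: float) -> int:
    bits = float_bits(x)
    exp = (bits >> 23) & 0xFF
    man = bits & ((1 << 23) - 1)
    return exp - 127 + (man != 0)

def fast_pow2(x: int) -> float:
    return struct.unpack("<f", struct.pack("<I", (x + 127) << 23))[0]

def ue8m0_byte(scale_inv: float) -> int:
    # The exponent byte stored when SCALE_UE8M0 is enabled.
    return (float_bits(scale_inv) >> 23) & 0xFF

for amax in (0.37, 1.0, 5.0, 448.0):
    e = fast_log2_ceil(amax / 448.0)      # amax * MAX_8BIT_INV
    assert e == math.ceil(math.log2(amax / 448.0))
    scale_inv = fast_pow2(e)              # power-of-two scale_inv
    assert scale_inv == 2.0 ** e
    print(amax, e, scale_inv, ue8m0_byte(scale_inv))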
sgl-kernel/include/sgl_kernel_ops.h (+6, -12)

@@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);

 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);

-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0);
-
-void sgl_per_token_group_quant_int8(
-    at::Tensor input,
-    at::Tensor output_q,
-    at::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max);
+    double min_8bit,
+    double max_8bit,
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m);

 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);

 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);

 void bmm_fp8(
sgl-kernel/python/sgl_kernel/__init__.py (+1, -2)

@@ -55,8 +55,7 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_fp8,
-    sgl_per_token_group_quant_int8,
+    sgl_per_token_group_quant_8bit,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
sgl-kernel/python/sgl_kernel/gemm.py (+15, -18)

@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-    )
-
-
-def sgl_per_token_group_quant_int8(
-    input: torch.Tensor,
-    output_q: torch.Tensor,
-    output_s: torch.Tensor,
-    group_size: int,
-    eps: float,
-    int8_min: float,
-    int8_max: float,
-) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
-        input, output_q, output_s, group_size, eps, int8_min, int8_max
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+        fuse_silu_and_mul,
+        masked_m,
+    )
sgl-kernel/python/sgl_kernel/test_utils.py (new file, mode 100644, +125)

import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(num_tokens, num_local_experts, gen_cpu=gen_cpu)
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor([num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int)
        else:
            raise NotImplementedError

        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(total: int, arr_len: int, gen_cpu, dtype=torch.int) -> list[int]:
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (a.dtype == b.dtype), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()

    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - a.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
sgl-kernel/tests/test_per_token_group_quant_8bit.py (+161, -58)

The test now covers int8 as well as FP8, fused SiLU-and-mul, and the masked per-expert layout; it takes its inputs from create_per_token_group_quant_test_data, compares quantized outputs with assert_all_close_or_tiny_diff instead of assert_fp8_all_close, and dumps diagnostics on failure. After the change the file reads:

import itertools
import os
import time
from pathlib import Path

import pytest
import torch
from sgl_kernel.test_utils import (
    assert_all_close_or_tiny_diff,
    create_per_token_group_quant_test_data,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import get_bool_env_var, is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

configs = list(
    itertools.product(
        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
        [16, 32, 64, 128],  # group_size
        [None],  # num_ranks
        [fp8_type_, torch.int8],  # dtype
        [
            dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
            dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
        ],
    )
) + list(
    itertools.product(
        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],  # num_tokens
        # TODO support more
        [2048],  # hidden_dim
        [128],  # group_size
        [8, 16, 32, 48],  # num_ranks
        [fp8_type_],  # dtype
        [
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
        ],
    )
)


@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags",
    configs,
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    num_ranks,
    dst_dtype,
    flags,
):
    print(f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}")

    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        pytest.skip("scale_ue8m0 only supported on Blackwell")
        return

    if (flags["scale_ue8m0"] and (group_size != 128)) or (
        (dst_dtype == torch.int8) and flags["column_major_scales"]
    ):
        pytest.skip()
        return

    x, masked_m = create_per_token_group_quant_test_data(
        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
    )

    # print("hack data!!!")
    # x = torch.full_like(x, fill_value=100)

    execute_kwargs = dict(
        x=x,
        masked_m=masked_m,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
    )

    def _postprocess(x_q, x_s):
        if masked_m is not None:
            print(f"Mask tokens after {masked_m} to be zero")
            for i in range(len(masked_m)):
                x_q[i, masked_m[i] :, :] = 0
                x_s[i, masked_m[i] :, :] = 0
        return x_q, x_s

    x_q_triton, x_s_triton = _postprocess(*triton_per_token_group_quant_8bit(**execute_kwargs))
    x_q_sglang, x_s_sglang = _postprocess(*sglang_per_token_group_quant_8bit(**execute_kwargs))

    try:
        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
        torch.testing.assert_close(
            x_s_triton.contiguous(),
            x_s_sglang.contiguous(),
            rtol=1e-3,
            atol=1e-5,
            msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
        )
    except AssertionError:
        # torch.set_printoptions(profile="full")
        print(f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}")
        print(f"{x=}")
        print(f"{masked_m=}")
        print(f"{x_q_triton=}")
        print(f"{x_s_triton=}")
        print(f"{x_q_sglang=}")
        print(f"{x_s_sglang=}")
        # torch.set_printoptions(profile="default")
        # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
        #     import matplotlib.pyplot as plt
        #
        #     base_stem = time.time()
        #     for name, value in [
        #         ("x_q", x_q_triton != x_q_sglang),
        #         ("x_s", x_s_triton != x_s_sglang),
        #     ]:
        #         value = value.reshape((-1, value.shape[-1]))
        #         plt.figure(figsize=(20, 20))
        #         plt.imshow((value * 1.0).cpu().numpy())
        #         p = Path(d) / f"{base_stem}_{name}.png"
        #         print(f"Write diff to {p}", flush=True)
        #         plt.savefig(p)
        raise


if __name__ == "__main__":
    pytest.main([__file__])
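The masked_m handling in this test (zeroing rows at and beyond masked_m[i] for each expert) mirrors the DeepEP-style padded layout produced by the new test_utils helpers: entry i of masked_m is the number of valid tokens for local expert i, and everything past it in the padded (num_local_experts, num_max_tokens, hidden) buffer is ignored. A small CPU-only look at the three split modes; the helpers are internal to the new test_utils module and the imbalanced values depend on the seed:

import torch
from sgl_kernel.test_utils import _compute_balanced_split, _compute_imbalanced_split

gen_cpu = torch.Generator(device="cpu")
gen_cpu.manual_seed(0)

total_tokens, num_local_experts = 768, 6
print(_compute_balanced_split(total_tokens, num_local_experts))
# tensor([128, 128, 128, 128, 128, 128], dtype=torch.int32)
print(_compute_imbalanced_split(total_tokens, num_local_experts, gen_cpu=gen_cpu))
# skewed per-expert counts that still sum to 768
print(torch.tensor([total_tokens] + [0] * (num_local_experts - 1)))
# the "extreme" mode used in the tests: one expert receives every token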