change / sglang · Commits · 6d55f60e

Unverified commit 6d55f60e, authored Sep 10, 2025 by Yineng Zhang, committed via GitHub on Sep 10, 2025

Revert "[1/2] Optimizations and refactors about quant kernel (#9534)" (#10292)

parent 033b75f5
Showing 11 changed files with 335 additions and 1002 deletions (+335 −1002):

python/sglang/srt/bench_utils.py                            +2   −4
python/sglang/srt/layers/quantization/fp8_kernel.py         +8   −29
python/sglang/srt/layers/quantization/int8_kernel.py        +2   −8
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py    +48  −203
sgl-kernel/csrc/common_extension.cc                         +8   −3
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu          +177 −447
sgl-kernel/include/sgl_kernel_ops.h                         +12  −6
sgl-kernel/python/sgl_kernel/__init__.py                    +2   −1
sgl-kernel/python/sgl_kernel/gemm.py                        +18  −15
sgl-kernel/python/sgl_kernel/test_utils.py (deleted)        +0   −125
sgl-kernel/tests/test_per_token_group_quant_8bit.py         +58  −161
python/sglang/srt/bench_utils.py  (+2 −4)

 import os
-import re
 import sys
 from contextlib import nullcontext
 ...
@@ -109,8 +108,7 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([int(re.search(name, line) is not None) for line in prof_lines])
-                == 1
+                sum([name in line for line in prof_lines]) == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
 ...
@@ -124,7 +122,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if re.search(name, line) is not None:
+            if name in line:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
 ...
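As context for the `bench_kineto` change above, here is a minimal, self-contained sketch (not the project's code) of the substring-based lookup that the restored version performs on the profiler table; the table lines, units map, and kernel name below are invented for illustration:

```python
# Hypothetical profiler table lines, mimicking a torch.profiler key_averages() table.
prof_lines = [
    "Name                                   Self CUDA   Self CUDA %   CUDA total   # of Calls",
    "per_token_group_quant_fp8_kernel         1.234ms        12.3%       1.234ms          300",
]

units = {"ms": 1e3, "us": 1e6}


def kernel_time_seconds(kernel_name: str) -> float:
    """Substring match (the behavior restored by this revert), then parse the
    second-to-last column as the time and the last column as the call count."""
    matches = [line for line in prof_lines if kernel_name in line]
    assert len(matches) == 1, f"expected exactly one row for {kernel_name}"
    time_str, num_str = matches[0].split()[-2], matches[0].split()[-1]
    for unit, scale in units.items():
        if time_str.endswith(unit):
            # Total time divided by scale and call count -> average seconds per call.
            return float(time_str[: -len(unit)]) / scale / int(num_str)
    raise ValueError(f"unknown unit in {time_str!r}")


print(kernel_time_seconds("per_token_group_quant_fp8_kernel"))
```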
python/sglang/srt/layers/quantization/fp8_kernel.py  (+8 −29)

 ...
@@ -43,17 +43,11 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    # Temporary
-    try:
-        from sgl_kernel import sgl_per_token_group_quant_8bit
-
-        enable_sgl_per_token_group_quant_8bit = True
-    except ImportError:
-        from sgl_kernel import sgl_per_token_group_quant_fp8
-
-        enable_sgl_per_token_group_quant_8bit = False
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )

 if _is_hip:
     if _use_aiter:
 ...
@@ -502,21 +496,6 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        # Temporary
-        if enable_sgl_per_token_group_quant_8bit:
-            sgl_per_token_group_quant_8bit(
-                x,
-                x_q,
-                x_s,
-                group_size,
-                eps,
-                fp8_min,
-                fp8_max,
-                scale_ue8m0,
-                fuse_silu_and_mul,
-                masked_m,
-            )
-        else:
-            sgl_per_token_group_quant_fp8(
-                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-            )
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 ...
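For readers unfamiliar with the op being called here, a hedged reference sketch of what per-token-group FP8 quantization computes, in plain PyTorch (row-major scales, no UE8M0 rounding; the function name and `eps` default are illustrative, not the kernel's exact semantics):

```python
import torch


def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Reference (unoptimized) per-token-group FP8 quantization.

    Each contiguous `group_size` slice of the hidden dimension gets one scale:
    scale = amax(group) / fp8_max, q = clamp(x / scale) cast to float8_e4m3fn.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    g = x.float().view(num_tokens, hidden_dim // group_size, group_size)
    amax = g.abs().amax(dim=-1).clamp_min(eps)      # (num_tokens, num_groups)
    x_s = amax / fp8_max                            # one scale per group
    x_q = (g / x_s.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    return x_q.view(num_tokens, hidden_dim).to(torch.float8_e4m3fn), x_s


# x_q, x_s = per_token_group_quant_fp8_ref(
#     torch.randn(4, 256, device="cuda", dtype=torch.bfloat16), group_size=128
# )
```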
python/sglang/srt/layers/quantization/int8_kernel.py  (+2 −8)

 ...
@@ -12,13 +12,7 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()
 if _is_cuda:
-    # Temporary
-    try:
-        from sgl_kernel import sgl_per_token_group_quant_8bit
-    except ImportError:
-        from sgl_kernel import (
-            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
-        )
+    from sgl_kernel import sgl_per_token_group_quant_int8

 logger = logging.getLogger(__name__)
 ...
@@ -210,7 +204,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)

     return x_q, x_s
 ...
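The int8 path follows the same per-group pattern; here is a small reference sketch under the assumption of symmetric quantization with round-to-nearest (the CUDA kernel's exact rounding mode is not shown in this diff):

```python
import torch


def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Reference per-token-group symmetric int8 quantization: one float32 scale
    per group of `group_size` channels, values rounded and clamped to [-128, 127]."""
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    g = x.float().view(num_tokens, hidden_dim // group_size, group_size)
    x_s = g.abs().amax(dim=-1).clamp_min(eps) / 127.0
    x_q = torch.round(g / x_s.unsqueeze(-1)).clamp_(-128, 127).to(torch.int8)
    return x_q.view(num_tokens, hidden_dim), x_s
```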
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py  (+48 −203)

 import itertools
-import os
 import time
 from functools import partial
 from pathlib import Path

 import torch
 import triton

-from sgl_kernel.test_utils import create_per_token_group_quant_test_data
 from sglang.srt.bench_utils import bench_kineto
 from sglang.srt.layers.quantization.fp8_kernel import (
 ...
@@ -21,231 +19,78 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
-
-if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
-    configs = [[
-        768 * 8,
-        2048,
-        128,
-        48,
-        fp8_type_,
-        dict(
-            column_major_scales=True,
-            scale_tma_aligned=True,
-            scale_ue8m0=True,
-            fuse_silu_and_mul=True,
-            # masked_layout_mode=None,
-            masked_layout_mode="balanced",
-            # masked_layout_mode="extreme",
-        ),
-    ]]
-elif mode_concentrated:
-    configs = list(
-        itertools.product(
-            [768],
-            [1536, 7168, 16384],
-            [128],
-            [None],
-            [fp8_type_],
-            [
-                dict(..., fuse_silu_and_mul=False, masked_layout_mode=None),
-            ],
-        )
-    ) + list(
-        itertools.product(
-            [768 * 8],
-            [2048],
-            [128],
-            [48],
-            [fp8_type_],
-            [
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode=None),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="balanced"),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="extreme"),
-            ],
-        )
-    )
-else:
-    configs = list(
-        itertools.product(
-            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
-            [1536, 7168, 16384],
-            [128],
-            [None],
-            [fp8_type_],
-            [
-                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False,
-                     fuse_silu_and_mul=False, masked_layout_mode=None),
-                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False,
-                     fuse_silu_and_mul=False, masked_layout_mode=None),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False,
-                     fuse_silu_and_mul=False, masked_layout_mode=None),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
-                     fuse_silu_and_mul=False, masked_layout_mode=None),
-            ],
-        )
-    ) + list(
-        itertools.product(
-            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
-            [2048],
-            [128],
-            [8, 16, 32, 48],
-            [fp8_type_],
-            [
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode=None),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="balanced"),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
-                dict(..., fuse_silu_and_mul=True, masked_layout_mode="extreme"),
-            ],
-        )
-    )
+num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
+# For DeepSeek V3/R1
+hidden_dim_range = [1536, 7168, 18432]
+# For DeepSeek V3/R1
+group_size_range = [128]
+# TODO test int8
+dst_dtype_range = [fp8_type_]
+flags_range = [
+    dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
+    dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
+    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
+    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
+]
+
+configs = list(
+    itertools.product(
+        num_tokens_range, hidden_dim_range, group_size_range, dst_dtype_range, flags_range
+    )
+)


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "hidden_dim", "group_size", "num_ranks", "dst_dtype", "flags"],
+        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
-        line_names=["Triton (Inaccurate)", "SGL Kernel"],
+        # Triton has multi kernels and we only report the time for the core one
+        line_names=["Triton", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
     )
 )
-def benchmark(num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider):
-    print(f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}")
-
-    x, masked_m = create_per_token_group_quant_test_data(
-        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
-    )
+def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
+    if flags["scale_ue8m0"] and group_size != 128:
+        return
+
+    device = torch.device("cuda")
+    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)

     fn, kernel_names = {
         "triton": (
-            triton_per_token_group_quant_8bit,
-            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
+            triton_per_token_group_quant_8bit,
+            "_per_token_group_quant_fp8",
         ),
         "sglang": (
             sglang_per_token_group_quant_8bit,
             "per_token_group_quant_8bit_kernel",
         ),
     }[provider]

-    bench_fn = lambda: fn(
-        x=x,
-        masked_m=masked_m,
-        group_size=group_size,
-        dst_dtype=dst_dtype,
-        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
-    )
-    time_s = bench_kineto(bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30)
+    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
+    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
     return time_s * 1e6
 ...
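The benchmark above relies on Triton's benchmarking harness; the following is a minimal, standalone example of the same `triton.testing.perf_report` / `triton.testing.Benchmark` pattern (toy workload and names are illustrative; requires a CUDA device):

```python
import torch
import triton


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],              # swept on the x axis
        x_vals=[128, 1024, 8192],
        line_arg="provider",                 # one line per provider
        line_vals=["torch"],
        line_names=["PyTorch"],
        styles=[("blue", "-")],
        ylabel="us",
        plot_name="toy-benchmark",
        args={"hidden_dim": 4096},           # fixed keyword arguments
    )
)
def toy_benchmark(num_tokens, hidden_dim, provider):
    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
    # triton.testing.do_bench returns milliseconds; convert to microseconds.
    ms = triton.testing.do_bench(lambda: x.abs().amax(dim=-1))
    return ms * 1e3


if __name__ == "__main__":
    toy_benchmark.run(print_data=True)
```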
sgl-kernel/csrc/common_extension.cc  (+8 −3)

 ...
@@ -121,9 +121,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
-  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
 ...
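Once registered through `m.def`/`m.impl` as above, the op is reachable from Python via `torch.ops`; a hedged sketch of that call path for the restored FP8 schema (the `x_s` shape here is an assumption for illustration; the higher-level sglang helpers normally allocate these buffers):

```python
import torch
import sgl_kernel  # importing the extension registers the "sgl_kernel" torch library

num_tokens, hidden_dim, group_size = 4, 256, 128
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)

# The schema registered with m.def(...) becomes callable as
# torch.ops.<library>.<op>.default(...); this mirrors the Python wrapper in gemm.py below.
fp8_max = float(torch.finfo(torch.float8_e4m3fn).max)
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
    x, x_q, x_s, group_size, 1e-10, -fp8_max, fp8_max, False
)
```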
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu  (+177 −447)

 #include <ATen/cuda/CUDAContext.h>
-#include <c10/util/Float8_e4m3fn.h>
+#include <cuda_fp8.h>
 #include <cmath>

 #include <flashinfer/vec_dtypes.cuh>

 #include "utils.h"

-template <int THREADS_PER_SUBWARP>
-__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
+__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
-  static_assert(
-      (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
-      "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
-  if constexpr (THREADS_PER_SUBWARP >= 16) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 8) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 4) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 2) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
-  }
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
   return val;
 }

-__device__ __forceinline__ float silu(const float& val) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  float half = 0.5f * val;
-  float t = __tanhf(half);
-  return half * (1.0f + t);
-#else
-  return val / (1.0f + __expf(-val));
-#endif
-}
-
-__device__ float2 fmul2_rn(float2 a, float2 b) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  return __fmul2_rn(a, b);
-#else
-  float2 result;
-  result.x = a.x * b.x;
-  result.y = a.y * b.y;
-  return result;
-#endif
-}
-
-// Copied and modified from DeepEP
-__forceinline__ __device__ float fast_pow2(int x) {
-  // We can ensure `-126 <= x and x <= 127`
-  uint32_t bits_x = (x + 127) << 23;
-  return *reinterpret_cast<float*>(&bits_x);
-}
-
-// Copied and modified from DeepEP
-__forceinline__ __device__ int fast_log2_ceil(float x) {
-  auto bits_x = *reinterpret_cast<uint32_t*>(&x);
-  auto exp_x = (bits_x >> 23) & 0xff;
-  auto man_bits = bits_x & ((1 << 23) - 1);
-  return exp_x - 127 + (man_bits != 0);
-}
-
-// Copied and modified from DeepEP
-template <bool ROUND_SCALE, typename dtype_info>
-__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
-  constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
-  if constexpr (ROUND_SCALE) {
-    auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
-    scale = fast_pow2(-exp_scale_inv);
-    scale_inv = fast_pow2(exp_scale_inv);
-  } else {
-    scale_inv = amax * MAX_8BIT_INV;
-    scale = dtype_info::MAX / amax;
-  }
-}
-
-// Copied and modified from DeepEP
-template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
-__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
-  if constexpr (SCALE_UE8M0) {
-    return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
-  } else {
-    return value;
-  }
-}
-
-__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
-  asm volatile("st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
-}
-
-__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
-  int4 ret;
-  asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
-               : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
-               : "l"(ptr));
-  return ret;
-}
-
-template <typename T>
-struct DtypeInfo;
-template <>
-struct DtypeInfo<int8_t> {
-  static constexpr float MIN = -128;
-  static constexpr float MAX = 127;
-};
-template <>
-struct DtypeInfo<c10::Float8_e4m3fn> {
-  static constexpr float MIN = -448;
-  static constexpr float MAX = 448;
-};
-
-template <bool FUSE_SILU_AND_MUL>
-__device__ __forceinline__ int compute_input_group_start_offset(
-    int expert_idx, int token_idx, int hidden_dim_group_idx, int hidden_size, int num_tokens_per_expert, int group_size) {
-  return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
-         token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
-}
-
-constexpr float LOCAL_ABSMAX_ABS = 1e-10;
-constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;
-
-struct NaiveScheduler {
-  static void compute_exec_config(
-      int threads_per_subwarp, int num_local_experts, int hidden_dim_num_groups, int num_groups,
-      int& subwarps_per_block, dim3& grid, dim3& block) {
-    subwarps_per_block = ([=]() -> int {
-      if (num_groups % 16 == 0) { return 16; }
-      else if (num_groups % 8 == 0) { return 8; }
-      else if (num_groups % 4 == 0) { return 4; }
-      else if (num_groups % 2 == 0) { return 2; }
-      return 1;
-    })();
-    grid = dim3(num_groups / subwarps_per_block);
-    block = dim3(subwarps_per_block * threads_per_subwarp);
-  }
-
-  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
-  __device__ __forceinline__ static void execute(
-      const int subwarps_per_block, const int hidden_dim_num_groups, const int32_t* masked_m,
-      const int num_tokens_per_expert, FUNC fn) {
-    constexpr int expert_idx = 0;
-    const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
-    const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
-    const int64_t group_id = blockIdx.x * subwarps_per_block + subwarp_id;
-    const int token_idx = group_id / hidden_dim_num_groups;
-    // At the hidden_size dimension, we are handling idx-th group
-    const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;
-    ...
-    fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
-  }
-};
-
-struct MaskedLayoutScheduler {
-  // TODO can be dynamically determined (which may be good when num rank is small)
-  static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
-  static constexpr int SUBWARPS_PER_BLOCK = 16;
-
-  static void compute_exec_config(
-      int threads_per_subwarp, int num_local_experts, int hidden_dim_num_groups, int num_groups,
-      int& subwarps_per_block, dim3& grid, dim3& block) {
-    subwarps_per_block = SUBWARPS_PER_BLOCK;
-    TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
-    grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
-    block = dim3(subwarps_per_block * threads_per_subwarp);
-  }
-
-  template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
-  __device__ __forceinline__ static void execute(
-      const int subwarps_per_block, const int hidden_dim_num_groups, const int32_t* masked_m,
-      const int num_tokens_per_expert, FUNC fn) {
-    const int expert_idx = blockIdx.z;
-    const int token_idx_start = blockIdx.y;
-    const int curr_expert_token_num = masked_m[expert_idx];
-    for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
-         token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
-      ...
-      fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
-    }
-  }
-};

 template <
-    typename SCHEDULER,
-    int GROUP_SIZE,
-    int THREADS_PER_SUBWARP,
     typename T,
     typename DST_DTYPE,
     bool IS_COLUMN_MAJOR = false,
     bool SCALE_UE8M0 = false,
-    bool FUSE_SILU_AND_MUL = false,
     typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
-    DST_DTYPE* __restrict__ output_q,
+    void* __restrict__ output_q,
     scale_packed_t* __restrict__ output_s,
-    const int32_t* __restrict__ masked_m,
-    const int subwarps_per_block,
-    const int hidden_dim_num_groups,
-    const int scale_expert_stride,
-    const int scale_hidden_stride,
-    const int num_tokens_per_expert) {
-  using dst_dtype_info = DtypeInfo<DST_DTYPE>;
-  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
-  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
-  SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>(
-      subwarps_per_block,
-      hidden_dim_num_groups,
-      masked_m,
-      num_tokens_per_expert,
-      [&](const int expert_idx, const int token_idx, const int hidden_dim_group_idx,
-          const int lane_id, const int input_group_start_offset) {
-        constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
-        constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
-        const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
-                                      token_idx * hidden_dim_num_groups + hidden_dim_group_idx;
-        ...  // ld_global_nc int4 loads (plus the secondary half when FUSE_SILU_AND_MUL), silu fusion,
-        ...  // GroupReduceMax<THREADS_PER_SUBWARP>, calculate_fp8_scales, extract_required_scale_format,
-        ...  // __nv_cvt_float2_to_fp8x2 conversion, st_global vectorized store
-      });
-}
+    const int group_size,
+    const int num_groups,
+    const int groups_per_block,
+    const float eps,
+    const float min_8bit,
+    const float max_8bit,
+    const int num_groups_per_row = 0,
+    const int scale_stride = 0) {
+  const int threads_per_group = 16;
+  const int64_t local_group_id = threadIdx.x / threads_per_group;
+  const int lane_id = threadIdx.x % threads_per_group;
+
+  const int64_t block_group_id = blockIdx.x * groups_per_block;
+  const int64_t global_group_id = block_group_id + local_group_id;
+  const int64_t block_group_offset = global_group_id * group_size;
+
+  float local_absmax = eps;
+
+  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
+  const T* group_input = input + block_group_offset;
+  DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
+  scale_element_t* scale_output;
+
+  if constexpr (IS_COLUMN_MAJOR) {
+    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+    const int row_idx = global_group_id / num_groups_per_row;
+    const int col_idx_unpacked = global_group_id % num_groups_per_row;
+    const int col_idx = col_idx_unpacked / num_elems_per_pack;
+    const int pack_idx = col_idx_unpacked % num_elems_per_pack;
+    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                   (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
+  } else {
+    static_assert(!SCALE_UE8M0);
+    scale_output = output_s + global_group_id;
+  }
+
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+  const int32_t num_vec_elems = group_size / vec_size;
+
+  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
+    vec_t input_vec;
+    input_vec.cast_load(group_input + i * vec_size);
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      float abs_val = fabsf(val);
+      local_absmax = fmaxf(local_absmax, abs_val);
+    }
+  }
+
+  local_absmax = GroupReduceMax(local_absmax, lane_id);
+
+  float y_s = local_absmax / max_8bit;
+  if constexpr (SCALE_UE8M0) {
+    y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
+  }
+
+  scale_element_t y_s_quant;
+  if constexpr (SCALE_UE8M0) {
+    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
+  } else {
+    y_s_quant = y_s;
+  }
+
+  if (lane_id == 0) {
+    *scale_output = y_s_quant;
+  }
+
+  for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
+    vec_t input_vec;
+    input_vec.cast_load(group_input + i * vec_size);
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
+      group_output[i * vec_size + j] = DST_DTYPE(q_val);
+    }
+  }
+}

 void sgl_per_token_group_quant_8bit(
-    // vanilla: (num_tokens, hidden_size)
-    // fuse_silu_and_mul: (num_tokens, hidden_size * 2)
-    // fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
     torch::Tensor input,
     torch::Tensor output_q,
     torch::Tensor output_s,
 ...
@@ -398,113 +121,120 @@ void sgl_per_token_group_quant_8bit(
     double eps,
     double min_8bit,
     double max_8bit,
-    bool scale_ue8m0,
-    bool fuse_silu_and_mul,
-    const std::optional<torch::Tensor>& masked_m) {
+    bool scale_ue8m0 = false) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
-  TORCH_CHECK(input.numel() > 0);
-  TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13);

+  const int num_groups = input.numel() / group_size;
   CHECK_EQ(input.numel() % group_size, 0);
-  const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1);
-
-  const bool masked_layout = masked_m.has_value();
-  TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));
-  const int num_local_experts = masked_layout ? input.size(0) : 1;
+  CHECK_EQ(output_s.dim(), 2);

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

+  constexpr int THREADS_PER_GROUP = 16;
+  int groups_per_block = 1;
+  if (num_groups % 16 == 0) {
+    groups_per_block = 16;
+  } else if (num_groups % 8 == 0) {
+    groups_per_block = 8;
+  } else if (num_groups % 4 == 0) {
+    groups_per_block = 4;
+  } else if (num_groups % 2 == 0) {
+    groups_per_block = 2;
+  }
+
   auto dst_type = output_q.scalar_type();
-  const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
-  const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size;
-  const int num_tokens_per_expert = static_cast<int>(output_q.size(-2));
-  const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
-  const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));
-
-#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \
-  do {                                                                                                     \
-    int subwarps_per_block;                                                                                \
-    dim3 grid, block;                                                                                      \
-    SCHEDULER::compute_exec_config(                                                                        \
-        THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
-                                                                                                           \
-    per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
-        <<<grid, block, 0, stream>>>(                                                                      \
-            static_cast<T*>(input.data_ptr()),                                                             \
-            static_cast<DST_DTYPE*>(output_q.data_ptr()),                                                  \
-            static_cast<output_s_dtype*>(output_s.data_ptr()),                                             \
-            static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0),                        \
-            subwarps_per_block,                                                                            \
-            hidden_dim_num_groups,                                                                         \
-            scale_expert_stride,                                                                           \
-            scale_hidden_stride,                                                                           \
-            num_tokens_per_expert);                                                                        \
-  } while (0)
-
-#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE)                                                            \
-  do {                                                                                                     \
-    constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16;                                                   \
-    TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T));               \
-                                                                                                           \
-    using dst_dtype_info = DtypeInfo<DST_DTYPE>;                                                           \
-    CHECK_EQ(dst_dtype_info::MIN, min_8bit);                                                               \
-    CHECK_EQ(dst_dtype_info::MAX, max_8bit);                                                               \
-                                                                                                           \
-    if (is_column_major) {                                                                                 \
-      if (scale_ue8m0) {                                                                                   \
-        if (fuse_silu_and_mul) {                                                                           \
-          if (masked_layout) {                                                                             \
-            LAUNCH_KERNEL_INNER(                                                                           \
-                MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
-          } else {                                                                                         \
-            LAUNCH_KERNEL_INNER(                                                                           \
-                NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
-          }                                                                                                \
-        } else {                                                                                           \
-          LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \
-        }                                                                                                  \
-      } else {                                                                                             \
-        LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true);   \
-      }                                                                                                    \
-    } else {                                                                                               \
-      LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false);    \
-    }                                                                                                      \
-  } while (0)
-
-#define LAUNCH_KERNEL_OUTER(...)                    \
-  switch (group_size) {                             \
-    case 16:                                        \
-      LAUNCH_KERNEL(16, __VA_ARGS__);               \
-      break;                                        \
-    case 32:                                        \
-      LAUNCH_KERNEL(32, __VA_ARGS__);               \
-      break;                                        \
-    case 64:                                        \
-      LAUNCH_KERNEL(64, __VA_ARGS__);               \
-      break;                                        \
-    case 128:                                       \
-      LAUNCH_KERNEL(128, __VA_ARGS__);              \
-      break;                                        \
-    default:                                        \
-      TORCH_CHECK(false, "Unsupported group_size"); \
-  }                                                 \
-  while (0)
-
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    if (dst_type == at::ScalarType::Char) {
-      LAUNCH_KERNEL_OUTER(scalar_t, int8_t);
-      return true;
-    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
-      LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn);
-      return true;
-    }
-    return false;
-  });
-
-#undef LAUNCH_KERNEL
-#undef LAUNCH_KERNEL_INNER
-}
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
+  const int hidden_dim = input.size(input.dim() - 1);
+  const int num_groups_per_row = hidden_dim / group_size;
+  const int scale_stride = output_s.stride(1);
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                                               \
+  do {                                                                                            \
+    dim3 grid(num_blocks);                                                                        \
+    dim3 block(num_threads);                                                                      \
+    if (is_column_major) {                                                                        \
+      if (scale_ue8m0) {                                                                          \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>(  \
+            static_cast<T*>(input.data_ptr()),                                                    \
+            output_q.data_ptr(),                                                                  \
+            static_cast<uint32_t*>(output_s.data_ptr()),                                          \
+            group_size,                                                                           \
+            num_groups,                                                                           \
+            groups_per_block,                                                                     \
+            (float)eps,                                                                           \
+            (float)min_8bit,                                                                      \
+            (float)max_8bit,                                                                      \
+            num_groups_per_row,                                                                   \
+            scale_stride);                                                                        \
+      } else {                                                                                    \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
+            static_cast<T*>(input.data_ptr()),                                                    \
+            output_q.data_ptr(),                                                                  \
+            static_cast<float*>(output_s.data_ptr()),                                             \
+            group_size,                                                                           \
+            num_groups,                                                                           \
+            groups_per_block,                                                                     \
+            (float)eps,                                                                           \
+            (float)min_8bit,                                                                      \
+            (float)max_8bit,                                                                      \
+            num_groups_per_row,                                                                   \
+            scale_stride);                                                                        \
+      }                                                                                           \
+    } else {                                                                                      \
+      assert(!scale_ue8m0);                                                                       \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>(         \
+          static_cast<T*>(input.data_ptr()),                                                      \
+          output_q.data_ptr(),                                                                    \
+          static_cast<float*>(output_s.data_ptr()),                                               \
+          group_size,                                                                             \
+          num_groups,                                                                             \
+          groups_per_block,                                                                       \
+          (float)eps,                                                                             \
+          (float)min_8bit,                                                                        \
+          (float)max_8bit);                                                                       \
+    }                                                                                             \
+  } while (0)
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
+    if (dst_type == at::ScalarType::Char) {
+      LAUNCH_KERNEL(scalar_t, int8_t);
+      return true;
+    } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
+      LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
+      return true;
+    }
+    return false;
+  });
+
+#undef LAUNCH_KERNEL
+}
+
+void sgl_per_token_group_quant_int8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
+}
+
+void sgl_per_token_group_quant_fp8(
+    torch::Tensor input,
+    torch::Tensor output_q,
+    torch::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double fp8_min,
+    double fp8_max,
+    bool scale_ue8m0) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
+}
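The removed DeepEP-derived helpers (`fast_log2_ceil`, `fast_pow2`, `extract_required_scale_format`) and the restored `exp2f(ceilf(log2f(...)))` path implement the same idea: round the per-group dequantization scale up to a power of two and store only its exponent (the UE8M0 format). A small Python illustration of that arithmetic, with values checked by hand (448 is the FP8 E4M3 maximum; the function name is illustrative):

```python
import math

FP8_E4M3_MAX = 448.0


def ue8m0_scale(amax: float) -> tuple[float, int]:
    """Round the scale amax/448 up to a power of two (what fast_log2_ceil /
    fast_pow2 compute bit-wise in the removed kernel) and return the scale
    together with its biased exponent byte, i.e. the UE8M0 encoding."""
    scale_inv = 2.0 ** math.ceil(math.log2(max(amax / FP8_E4M3_MAX, 1e-10)))
    exponent_byte = int(math.log2(scale_inv)) + 127  # what extract_required_scale_format returns
    return scale_inv, exponent_byte


print(ue8m0_scale(30.0))  # amax=30 -> 30/448 ~= 0.067 -> rounded up to 0.125, exponent byte 124
```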
sgl-kernel/include/sgl_kernel_ops.h  (+12 −6)

 ...
@@ -207,17 +207,23 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);
 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
-void sgl_per_token_group_quant_8bit(
+void sgl_per_token_group_quant_fp8(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double min_8bit,
-    double max_8bit,
-    bool scale_ue8m0,
-    bool fuse_silu_and_mul,
-    const std::optional<torch::Tensor>& masked_m);
+    double fp8_min,
+    double fp8_max,
+    bool scale_ue8m0);
+void sgl_per_token_group_quant_int8(
+    at::Tensor input,
+    at::Tensor output_q,
+    at::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void bmm_fp8(
 ...
sgl-kernel/python/sgl_kernel/__init__.py  (+2 −1)

 ...
@@ -58,7 +58,8 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_8bit,
+    sgl_per_token_group_quant_fp8,
+    sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
 ...
sgl-kernel/python/sgl_kernel/gemm.py  (+18 −15)

 ...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_8bit(
+def sgl_per_token_group_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ...
@@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
-    fuse_silu_and_mul: bool = False,
-    masked_m: Optional[torch.Tensor] = None,
+    scale_ue8m0: bool = False,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
-        input,
-        output_q,
-        output_s,
-        group_size,
-        eps,
-        fp8_min,
-        fp8_max,
-        scale_ue8m0,
-        fuse_silu_and_mul,
-        masked_m,
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+    )
+
+
+def sgl_per_token_group_quant_int8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    int8_min: float,
+    int8_max: float,
+) -> None:
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
+        input, output_q, output_s, group_size, eps, int8_min, int8_max
+    )
 ...
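A hedged usage sketch of the restored Python wrapper for the non-column-major, non-UE8M0 case (the preallocated `x_s` shape is an assumption for illustration; the higher-level `sglang_per_token_group_quant_fp8` helper normally allocates these buffers with the right layout):

```python
import torch
from sgl_kernel import sgl_per_token_group_quant_fp8

num_tokens, hidden_dim, group_size = 16, 1024, 128
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)

fp8_max = float(torch.finfo(torch.float8_e4m3fn).max)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, -fp8_max, fp8_max, False)
```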
sgl-kernel/python/sgl_kernel/test_utils.py  (deleted, +0 −125; entire file removed)

import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(num_tokens, num_local_experts, gen_cpu=gen_cpu)
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor([num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int)
        else:
            raise NotImplementedError

        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(total: int, arr_len: int, gen_cpu, dtype=torch.int) -> list[int]:
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (a.dtype == b.dtype), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()

    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - a.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
sgl-kernel/tests/test_per_token_group_quant_8bit.py  (+58 −161)

 import itertools
-import os
-import time
-from pathlib import Path

 import pytest
 import torch
-from sgl_kernel.test_utils import (
-    assert_all_close_or_tiny_diff,
-    create_per_token_group_quant_test_data,
-)
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
 )
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
-from sglang.srt.utils import get_bool_env_var, is_hip
+from sglang.srt.layers.quantization.utils import assert_fp8_all_close
+from sglang.srt.utils import is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-configs = list(
-    itertools.product(
-        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
-        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
-        [16, 32, 64, 128],  # group_size
-        [None],  # num_ranks
-        # TODO test int8
-        [fp8_type_],  # dtype
-        [
-            dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False,
-                 fuse_silu_and_mul=False, masked_layout_mode=None),
-            dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False,
-                 fuse_silu_and_mul=False, masked_layout_mode=None),
-            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False,
-                 fuse_silu_and_mul=False, masked_layout_mode=None),
-            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
-                 fuse_silu_and_mul=False, masked_layout_mode=None),
-        ],
-    )
-) + list(
-    itertools.product(
-        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
-        [2048],  # TODO support more
-        [128],
-        [8, 16, 32, 48],
-        [fp8_type_],
-        [
-            dict(..., fuse_silu_and_mul=True, masked_layout_mode=None),
-            dict(..., fuse_silu_and_mul=True, masked_layout_mode="balanced"),
-            dict(..., fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
-            dict(..., fuse_silu_and_mul=True, masked_layout_mode="extreme"),
-        ],
-    )
-)
-
-
-@pytest.mark.parametrize("num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs)
+@pytest.mark.parametrize(
+    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
+    list(
+        itertools.product(
+            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
+            [256, 512, 1024, 2048, 4096],  # hidden_dim
+            [8, 16, 32, 64, 128],  # group_size
+            [fp8_type_, torch.int8],  # dtype
+            [
+                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
+                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
+            ],
+        )
+    ),
+)
 def test_per_token_group_quant_with_column_major(
     num_tokens,
     hidden_dim,
     group_size,
-    num_ranks,
     dst_dtype,
     flags,
 ):
-    print(f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}")
-
-    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
-        pytest.skip()
-        return
-    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
-        pytest.skip("scale_ue8m0 only supported on Blackwell")
-        return
-
-    x, masked_m = create_per_token_group_quant_test_data(
-        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
-    )
-    # print("hack data!!!")
-    # x = torch.full_like(x, fill_value=100)
+    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
+    if flags["scale_ue8m0"] and (arch_major <= 9):
+        pytest.skip("Only Blackwell need ue8m0 fusion")
+        return
+    if (flags["scale_ue8m0"] and (group_size != 128)) or (
+        (dst_dtype == torch.int8) and flags["column_major_scales"]
+    ):
+        pytest.skip()
+        return
+
+    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

     execute_kwargs = dict(
         x=x,
-        masked_m=masked_m,
         group_size=group_size,
         eps=1e-10,
         dst_dtype=dst_dtype,
-        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
+        **flags,
     )

-    def _postprocess(x_q, x_s):
-        if masked_m is not None:
-            print(f"Mask tokens after {masked_m} to be zero")
-            for i in range(len(masked_m)):
-                x_q[i, masked_m[i] :, :] = 0
-                x_s[i, masked_m[i] :, :] = 0
-        return x_q, x_s
-
-    x_q_triton, x_s_triton = _postprocess(*triton_per_token_group_quant_8bit(**execute_kwargs))
-    x_q_sglang, x_s_sglang = _postprocess(*sglang_per_token_group_quant_8bit(**execute_kwargs))
-
-    try:
-        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
-        torch.testing.assert_close(
-            x_s_triton.contiguous(),
-            x_s_sglang.contiguous(),
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
+
+    assert_fp8_all_close(x_q_triton, x_q_sglang)
+    torch.testing.assert_close(
+        x_s_triton.contiguous(),
+        x_s_sglang.contiguous(),
 ...
@@ -165,35 +91,6 @@ def test_per_token_group_quant_with_column_major(
         atol=1e-5,
         msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
     )
-    except AssertionError:
-        # torch.set_printoptions(profile="full")
-        print(f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}")
-        print(f"{x=}")
-        print(f"{masked_m=}")
-        print(f"{x_q_triton=}")
-        print(f"{x_s_triton=}")
-        print(f"{x_q_sglang=}")
-        print(f"{x_s_sglang=}")
-        # torch.set_printoptions(profile="default")
-        # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
-        #     import matplotlib.pyplot as plt
-        #
-        #     base_stem = time.time()
-        #     for name, value in [
-        #         ("x_q", x_q_triton != x_q_sglang),
-        #         ("x_s", x_s_triton != x_s_sglang),
-        #     ]:
-        #         value = value.reshape((-1, value.shape[-1]))
-        #         plt.figure(figsize=(20, 20))
-        #         plt.imshow((value * 1.0).cpu().numpy())
-        #         p = Path(d) / f"{base_stem}_{name}.png"
-        #         print(f"Write diff to {p}", flush=True)
-        #         plt.savefig(p)
-        raise

 if __name__ == "__main__":
 ...