Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59b2f7b6
Unverified
Commit
59b2f7b6
authored
Apr 11, 2026
by
Wei Zhao
Committed by
GitHub
Apr 11, 2026
Browse files
[Perf] Fuse Zero Initializer for FP8 DeepGemm Block Quant Kernel (#39547)
Signed-off-by:
wzhao18
<
wzhao18.sz@gmail.com
>
parent
92feb999
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
180 additions
and
49 deletions
+180
-49
csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
...rch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+70
-49
tests/kernels/quantization/test_per_token_group_quant.py
tests/kernels/quantization/test_per_token_group_quant.py
+110
-0
No files found.
csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
View file @
59b2f7b6
...
...
@@ -240,8 +240,9 @@ template <typename T, typename DST_DTYPE>
__global__
void
per_token_group_quant_8bit_packed_kernel
(
const
T
*
__restrict__
input
,
void
*
__restrict__
output_q
,
unsigned
int
*
__restrict__
output_s_packed
,
const
int
group_size
,
const
int
num_groups
,
const
int
groups_per_block
,
const
int
groups_per_row
,
const
int
mn
,
const
int
tma_aligned_mn
,
const
float
eps
,
const
int
num_groups_padded
,
const
int
groups_per_block
,
const
int
padded_groups_per_row
,
const
int
groups_per_row
,
const
int
mn
,
const
int
tma_aligned_mn
,
const
int
num_scale_elems
,
const
float
eps
,
const
float
min_8bit
,
const
float
max_8bit
)
{
const
int
threads_per_group
=
16
;
const
int64_t
local_group_id
=
threadIdx
.
x
/
threads_per_group
;
...
...
@@ -249,51 +250,62 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
const
int64_t
block_group_id
=
blockIdx
.
x
*
groups_per_block
;
const
int64_t
global_group_id
=
block_group_id
+
local_group_id
;
if
(
global_group_id
>=
num_groups
)
{
if
(
global_group_id
>=
num_groups
_padded
)
{
return
;
}
const
int64_t
block_group_offset
=
global_group_id
*
group_size
;
// map flat group id to 2D indices (mn_idx, sf_k_idx)
const
int
sf_k_idx
=
static_cast
<
int
>
(
global_group_id
%
padded_groups_per_row
);
const
int
mn_idx
=
static_cast
<
int
>
(
global_group_id
/
padded_groups_per_row
);
const
T
*
group_input
=
input
+
block_group_offset
;
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
block_group_offset
;
// whether it is a valid group (not padding)
const
bool
is_valid_group
=
(
mn_idx
<
mn
)
&&
(
sf_k_idx
<
groups_per_row
);
// shared memory to cache each group's data to avoid double DRAM reads.
extern
__shared__
__align__
(
16
)
char
smem_raw
[];
T
*
smem
=
reinterpret_cast
<
T
*>
(
smem_raw
);
T
*
smem_group
=
smem
+
local_group_id
*
group_size
;
const
float
y_s
=
ComputeGroupScale
<
T
,
true
>
(
group_input
,
smem_group
,
group_size
,
lane_id
,
threads_per_group
,
eps
,
max_8bit
);
// pack 4 scales into a uint32
if
(
lane_id
==
0
)
{
// map flat group id to 2D indices (mn_idx, sf_k_idx)
const
int
sf_k_idx
=
static_cast
<
int
>
(
global_group_id
%
groups_per_row
);
const
int
mn_idx
=
static_cast
<
int
>
(
global_group_id
/
groups_per_row
);
// compute scale for valid groups
float
y_s
=
0.
f
;
if
(
is_valid_group
)
{
const
T
*
group_input
=
input
+
static_cast
<
int64_t
>
(
mn_idx
)
*
groups_per_row
*
group_size
+
sf_k_idx
*
group_size
;
y_s
=
ComputeGroupScale
<
T
,
true
>
(
group_input
,
smem_group
,
group_size
,
lane_id
,
threads_per_group
,
eps
,
max_8bit
);
}
if
(
mn_idx
<
mn
)
{
// pack 4 scales into a uint32 exponent
if
(
lane_id
==
0
)
{
// each uint32 in output_s_packed stores 4 packed scales
const
int
sf_k_pack_idx
=
sf_k_idx
/
4
;
const
int
pos
=
sf_k_idx
%
4
;
const
int
out_idx
=
sf_k_pack_idx
*
tma_aligned_mn
+
mn_idx
;
if
(
is_valid_group
)
{
// reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
// exponent, and place it into the correct byte of the 32-bit word.
const
unsigned
int
bits
=
__float_as_uint
(
y_s
);
const
unsigned
int
exponent
=
(
bits
>>
23u
)
&
0xffu
;
const
unsigned
int
contrib
=
exponent
<<
(
pos
*
8u
);
const
int
out_idx
=
sf_k_pack_idx
*
tma_aligned_mn
+
mn_idx
;
// atomically OR 8-bit exponent into the packed scales buffer
atomicOr
(
output_s_packed
+
out_idx
,
contrib
);
const
uint8_t
exponent
=
static_cast
<
uint8_t
>
((
bits
>>
23u
)
&
0xffu
);
reinterpret_cast
<
uint8_t
*>
(
output_s_packed
)[
out_idx
*
4
+
pos
]
=
exponent
;
}
else
if
(
out_idx
<
num_scale_elems
)
{
// write zero for padding groups if within bounds of output_s_packed
reinterpret_cast
<
uint8_t
*>
(
output_s_packed
)[
out_idx
*
4
+
pos
]
=
0
;
}
}
__syncthreads
();
if
(
is_valid_group
)
{
DST_DTYPE
*
group_output
=
static_cast
<
DST_DTYPE
*>
(
output_q
)
+
static_cast
<
int64_t
>
(
mn_idx
)
*
groups_per_row
*
group_size
+
sf_k_idx
*
group_size
;
QuantizeGroup
<
T
,
DST_DTYPE
>
(
smem_group
,
group_output
,
group_size
,
lane_id
,
threads_per_group
,
y_s
,
min_8bit
,
max_8bit
);
}
}
void
per_token_group_quant_8bit_packed
(
const
torch
::
stable
::
Tensor
&
input
,
...
...
@@ -310,7 +322,6 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
const
int64_t
mn
=
input
.
numel
()
/
k
;
const
int64_t
groups_per_row
=
k
/
group_size
;
const
int64_t
num_groups
=
mn
*
groups_per_row
;
STD_TORCH_CHECK
(
output_s_packed
.
dim
()
==
2
,
"output_s_packed must be 2D, got dim="
,
output_s_packed
.
dim
(),
...
...
@@ -330,21 +341,30 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
"output_s_packed shape must be ["
,
mn
,
", "
,
k_num_packed_sfk
,
"], but got ["
,
output_s_packed
.
size
(
0
),
", "
,
output_s_packed
.
size
(
1
),
"]."
);
// Verify column-major TMA-aligned layout
STD_TORCH_CHECK
(
output_s_packed
.
stride
(
0
)
==
1
&&
output_s_packed
.
stride
(
1
)
==
tma_aligned_mn
,
"output_s_packed must have strides [1, "
,
tma_aligned_mn
,
"], but got ["
,
output_s_packed
.
stride
(
0
),
", "
,
output_s_packed
.
stride
(
1
),
"]."
);
cudaStream_t
stream
=
get_current_cuda_stream
();
constexpr
int
THREADS_PER_GROUP
=
16
;
const
int
groups_per_block
=
GetGroupsPerBlock
(
num_groups
);
// Expand the grid to cover MN and K padding so every byte in
// output_s_packed is written (padding bytes get zeroed by the kernel).
const
int64_t
padded_groups_per_row
=
k_num_packed_sfk
*
4
;
const
int64_t
num_groups_padded
=
tma_aligned_mn
*
padded_groups_per_row
;
// Number of elements in output_s_packed.
const
int64_t
num_scale_elems
=
mn
+
(
k_num_packed_sfk
-
1
)
*
tma_aligned_mn
;
const
int
groups_per_block
=
GetGroupsPerBlock
(
num_groups_padded
);
auto
dst_type
=
output_q
.
scalar_type
();
const
int
num_blocks
=
num_groups
/
groups_per_block
;
const
int
num_blocks
=
num_groups
_padded
/
groups_per_block
;
const
int
num_threads
=
groups_per_block
*
THREADS_PER_GROUP
;
// zero-initialize packed scales, since we use atomicOr to accumulate
// exponents from different groups.
torch
::
stable
::
zero_
(
output_s_packed
);
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
...
...
@@ -355,11 +375,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
<<<grid, block, smem_bytes, stream>>>( \
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
static_cast<int>(group_size), static_cast<int>(num_groups), \
groups_per_block, static_cast<int>(groups_per_row), \
static_cast<int>(mn), static_cast<int>(tma_aligned_mn), \
static_cast<float>(eps), static_cast<float>(min_8bit), \
static_cast<float>(max_8bit)); \
static_cast<int>(group_size), static_cast<int>(num_groups_padded), \
groups_per_block, static_cast<int>(padded_groups_per_row), \
static_cast<int>(groups_per_row), static_cast<int>(mn), \
static_cast<int>(tma_aligned_mn), \
static_cast<int>(num_scale_elems), static_cast<float>(eps), \
static_cast<float>(min_8bit), static_cast<float>(max_8bit)); \
} while (0)
VLLM_STABLE_DISPATCH_FLOATING_TYPES
(
...
...
tests/kernels/quantization/test_per_token_group_quant.py
View file @
59b2f7b6
...
...
@@ -48,6 +48,116 @@ def test_per_token_group_quant_fp8(
assert
torch
.
allclose
(
scale
,
ref_s
,
atol
=
0.01
,
rtol
=
0.01
)
@pytest.mark.parametrize(
    "num_tokens,hidden_dim,group_size",
    [
        # No padding: mn=4 (mult of 4), groups_per_row=56 (mult of 4)
        (4, 7168, 128),
        # MN padding only: mn=1, tma_aligned_mn=4
        (1, 7168, 128),
        # MN padding only: mn=3, tma_aligned_mn=4
        (3, 7168, 128),
        # K padding only: groups_per_row=5 (5%4=1)
        (4, 640, 128),
        # K padding only: groups_per_row=6 (6%4=2)
        (4, 768, 128),
        # Single packed column, no padding: k_num_packed=1, mn%4=0
        (4, 384, 128),
        # Both MN and K padding
        (1, 384, 128),
        (3, 640, 128),
        # Larger shapes with no padding
        (64, 7168, 128),
        (128, 14336, 128),
        # Larger shapes with padding
        (127, 7168, 128),
        (253, 640, 128),
        # Non-power-of-2 group size
        (4, 768, 96),  # 768/96=8 groups, no padding
        (3, 768, 96),  # 768/96=8 groups, MN padding
        (4, 480, 96),  # 480/96=5 groups, K padding
        (1, 480, 96),  # both MN and K padding
    ],
)
@pytest.mark.parametrize("poisoned_scales", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8_packed(
    num_tokens, hidden_dim, group_size, poisoned_scales
):
    """Test the packed DeepGEMM quantization kernel against the Triton
    reference (row-major, UE8M0 scales).

    Args:
        num_tokens: Number of rows (``mn``) of the bfloat16 input.
        hidden_dim: Number of columns of the input; split into
            ``hidden_dim // group_size`` quantization groups per row.
        group_size: Number of elements sharing one scale.
        poisoned_scales: When True, the scale buffer is pre-filled with
            ``0x7F7F7F7F`` and the custom op is invoked directly, so the test
            only passes if the kernel itself overwrites every word (including
            padding bytes). When False, the regular Python wrapper is used.

    The test checks two things: the quantized values equal the reference
    exactly, and the packed scale storage (four 8-bit exponents per int32,
    laid out with a column stride of ``tma_aligned_mn``) equals the layout
    rebuilt from the reference scales, with zeros in all padding positions.
    """
    device = "cuda"
    torch.manual_seed(42)
    x = (
        torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16)
        * 8
    )

    mn = num_tokens
    groups_per_row = hidden_dim // group_size
    # Four scales are packed into each int32 along K.
    k_num_packed = (groups_per_row + 3) // 4
    # Row count rounded up to a multiple of 4; used as the column stride.
    tma_aligned_mn = ((mn + 3) // 4) * 4
    # Number of int32 words spanned by the strided (mn, k_num_packed) buffer.
    num_scale_elems = mn + (k_num_packed - 1) * tma_aligned_mn

    if poisoned_scales:
        # Call the kernel with poisoned scale buffer to
        # ensure padded indices are correctly zeroed.
        fp8_dtype = torch.float8_e4m3fn
        finfo = torch.finfo(fp8_dtype)
        out_q = torch.empty_like(x, dtype=fp8_dtype)
        out_s_packed = torch.empty_strided(
            (mn, k_num_packed),
            (1, tma_aligned_mn),
            device=device,
            dtype=torch.int32,
        )
        torch.as_strided(out_s_packed, (num_scale_elems,), (1,)).fill_(0x7F7F7F7F)
        torch.ops._C.per_token_group_fp8_quant_packed(
            x,
            out_q,
            out_s_packed,
            group_size,
            1e-10,
            finfo.min,
            finfo.max,
        )
    else:
        out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
            x,
            group_size=group_size,
            use_ue8m0=True,
        )

    # Triton reference (row-major float32 scales, UE8M0)
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x,
            group_size,
            use_ue8m0=True,
        )

    # Quantized values must match.
    assert torch.equal(out_q, ref_q), "Quantized output mismatch"

    # Verify packed scales (valid exponents + padding zeros).
    ref_s_flat = ref_s.reshape(mn, groups_per_row)
    # Extract the 8-bit biased exponent of every reference scale and copy the
    # whole table to the host once; calling .item() per element on a CUDA
    # tensor would force one device synchronization per scale.
    ref_exponents = ((ref_s_flat.view(torch.int32) >> 23) & 0xFF).cpu().tolist()

    # Pack with unbounded Python ints so a top-byte exponent >= 128 cannot
    # overflow while OR-ing; padding words stay zero.
    expected_words = [0] * num_scale_elems
    for row, row_exponents in enumerate(ref_exponents):
        for g, exponent in enumerate(row_exponents):
            pack_col = g // 4
            pos = g % 4
            idx = pack_col * tma_aligned_mn + row
            expected_words[idx] |= exponent << (pos * 8)

    # Reinterpret each 32-bit pattern as a signed value so it fits in int32.
    expected = torch.tensor(
        [w - (1 << 32) if w >= (1 << 31) else w for w in expected_words],
        dtype=torch.int32,
        device="cpu",
    )

    actual = torch.as_strided(out_s_packed, (num_scale_elems,), (1,)).cpu()
    assert torch.equal(actual, expected), (
        f"Packed scale storage mismatch.\n"
        f"First diff at index "
        f"{(actual != expected).nonzero(as_tuple=True)[0][0].item()}"
    )
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
32
,
128
),
(
64
,
256
),
(
16
,
512
)])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA not available"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment