Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d08b356e
Unverified
Commit
d08b356e
authored
Jan 22, 2026
by
Xin Yang
Committed by
GitHub
Jan 22, 2026
Browse files
[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
f7448101
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
75 additions
and
17 deletions
+75
-17
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
...hmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+3
-3
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+3
-0
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+30
-9
tests/kernels/quantization/test_per_token_group_quant.py
tests/kernels/quantization/test_per_token_group_quant.py
+6
-2
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+5
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+23
-3
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+5
-0
No files found.
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
View file @
d08b356e
...
...
@@ -14,7 +14,6 @@ from vllm.triton_utils import triton
from
vllm.utils.deep_gemm
import
(
calc_diff
,
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
per_block_cast_to_fp8
,
)
...
...
@@ -48,8 +47,9 @@ def benchmark_shape(
block_size
=
[
128
,
128
]
# Pre-quantize A for all implementations
A_deepgemm
,
A_scale_deepgemm
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
])
A_scale_deepgemm
=
get_col_major_tma_aligned_tensor
(
A_scale_deepgemm
)
A_deepgemm
,
A_scale_deepgemm
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
],
column_major_scales
=
True
,
tma_aligned_scales
=
True
)
C_deepgemm
=
torch
.
empty
((
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
A_vllm
,
A_scale_vllm
=
per_token_group_quant_fp8
(
A
,
block_size
[
1
])
A_vllm_cutlass
,
A_scale_vllm_cutlass
=
per_token_group_quant_fp8
(
...
...
tests/kernels/quant_utils.py
View file @
d08b356e
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
_ceil_to_ue8m0
,
is_deep_gemm_e8m0_used
from
vllm.utils.math_utils
import
round_up
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8(
x_
=
x
.
reshape
(
x
.
numel
()
//
group_size
,
group_size
)
amax
=
x_
.
abs
().
max
(
dim
=-
1
,
keepdim
=
True
)[
0
].
clamp
(
min
=
eps
).
to
(
torch
.
float32
)
x_s
=
amax
/
fp8_max
if
is_deep_gemm_e8m0_used
():
x_s
=
_ceil_to_ue8m0
(
x_s
)
x_q
=
(
x_
/
x_s
).
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
dtype
)
x_q
=
x_q
.
reshape
(
x
.
shape
)
x_s
=
x_s
.
reshape
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,))
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
d08b356e
...
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
get_tma_aligned_size
,
per_block_cast_to_fp8
,
should_use_deepgemm_for_fp8_linear
,
)
...
...
@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS
=
[
7
,
2050
]
D
=
[
512
,
4096
,
5120
,
13824
]
GROUP_SIZE
=
[
64
,
128
,
512
]
COLUMN_MAJOR_SCALES
=
[
True
,
False
]
TMA_ALIGNED_SCALES
=
[
True
,
False
]
M
=
[
1
,
7
,
8
,
83
,
84
,
4096
]
N
=
[
128
,
512
,
7168
,
7748
,
13824
]
K
=
[
256
,
3884
,
4096
,
13824
,
16384
]
...
...
@@ -63,20 +65,40 @@ def setup_cuda():
reason
=
"This platform supports e4m3fnuz, not e4m3fn."
,
)
@
pytest
.
mark
.
parametrize
(
"num_tokens,d,dtype,group_size,seed"
,
itertools
.
product
(
NUM_TOKENS
,
D
,
DTYPES
,
GROUP_SIZE
,
SEEDS
),
"num_tokens,d,dtype,group_size,column_major_scales,tma_aligned_scales,seed"
,
itertools
.
product
(
NUM_TOKENS
,
D
,
DTYPES
,
GROUP_SIZE
,
COLUMN_MAJOR_SCALES
,
TMA_ALIGNED_SCALES
,
SEEDS
,
),
)
@
torch
.
inference_mode
()
def
test_per_token_group_quant_fp8
(
num_tokens
,
d
,
dtype
,
group_size
,
seed
):
def
test_per_token_group_quant_fp8
(
num_tokens
,
d
,
dtype
,
group_size
,
column_major_scales
,
tma_aligned_scales
,
seed
):
torch
.
manual_seed
(
seed
)
x
=
torch
.
rand
(
num_tokens
,
d
,
dtype
=
dtype
)
ref_out
,
ref_scale
=
native_per_token_group_quant_fp8
(
x
,
group_size
)
out
,
scale
=
per_token_group_quant_fp8
(
x
,
group_size
)
out
,
scale
=
per_token_group_quant_fp8
(
x
,
group_size
,
column_major_scales
=
column_major_scales
,
tma_aligned_scales
=
tma_aligned_scales
,
)
assert
torch
.
allclose
(
out
.
to
(
torch
.
float32
),
ref_out
.
to
(
torch
.
float32
),
rtol
=
0.15
)
assert
torch
.
allclose
(
scale
,
ref_scale
)
if
column_major_scales
:
assert
scale
.
stride
()[
-
2
]
==
1
if
tma_aligned_scales
:
assert
scale
.
stride
()[
-
1
]
==
get_tma_aligned_size
(
num_tokens
,
4
)
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
...
...
@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
):
pytest
.
skip
(
f
"Skipping test; invalid size
{
M
}
,
{
N
}
,
{
K
}
"
)
A_fp8
,
As_fp8
=
per_token_group_quant_fp8
(
A_fp32
,
block_size
[
1
])
A_fp8
,
As_fp8
=
per_token_group_quant_fp8
(
A_fp32
,
block_size
[
1
],
column_major_scales
=
True
,
tma_aligned_scales
=
True
)
B_fp8
,
Bs_fp8
=
per_block_cast_to_fp8
(
B_fp32
,
block_size
=
block_size
)
As
=
As_fp8
.
to
(
torch
.
float32
)
...
...
@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out
=
native_w8a8_block_matmul
(
A_fp8
,
B_fp8
,
As
,
Bs
,
block_size
,
out_dtype
)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8
=
get_col_major_tma_aligned_tensor
(
As_fp8
)
out
=
torch
.
zeros
((
M
,
N
),
device
=
"cuda"
,
dtype
=
out_dtype
)
assert
As_fp8
.
shape
==
(
M
,
(
K
+
127
)
//
128
),
(
...
...
tests/kernels/quantization/test_per_token_group_quant.py
View file @
d08b356e
...
...
@@ -8,13 +8,16 @@ import torch
from
vllm.model_executor.layers.quantization.utils
import
fp8_utils
,
int8_utils
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
32
,
128
),
(
64
,
256
),
(
16
,
512
)])
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
31
,
128
),
(
32
,
128
),
(
63
,
256
),
(
64
,
256
),
(
16
,
512
)]
)
@
pytest
.
mark
.
parametrize
(
"column_major"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"tma_aligned"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"scale_ue8m0"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA not available"
)
def
test_per_token_group_quant_fp8
(
shape
,
column_major
:
bool
,
scale_ue8m0
:
bool
,
group_size
:
int
shape
,
column_major
:
bool
,
tma_aligned
:
bool
,
scale_ue8m0
:
bool
,
group_size
:
int
):
device
=
"cuda"
...
...
@@ -28,6 +31,7 @@ def test_per_token_group_quant_fp8(
x
,
group_size
,
column_major_scales
=
column_major
,
tma_aligned_scales
=
tma_aligned
,
use_ue8m0
=
scale_ue8m0
,
)
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
d08b356e
...
...
@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
group_shape
:
GroupShape
,
num_token_padding
:
int
|
None
=
None
,
column_major_scales
:
bool
=
False
,
tma_aligned_scales
:
bool
=
False
,
use_ue8m0
:
bool
|
None
=
None
,
# for Torch compile
):
"""
...
...
@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in
column major format
"""
...
...
@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
self
.
use_per_token_if_dynamic
=
group_shape
==
GroupShape
.
PER_TOKEN
self
.
num_token_padding
=
num_token_padding
self
.
column_major_scales
=
column_major_scales
self
.
tma_aligned_scales
=
tma_aligned_scales
self
.
use_ue8m0
=
use_ue8m0
self
.
use_aiter
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
...
...
@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
x
,
group_size
=
self
.
group_size
,
column_major_scales
=
self
.
column_major_scales
,
tma_aligned_scales
=
self
.
tma_aligned_scales
,
dtype
=
_FP8_DTYPE
,
use_ue8m0
=
self
.
use_ue8m0
,
)
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
d08b356e
...
...
@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
from
vllm.utils.deep_gemm
import
(
DeepGemmQuantScaleFMT
,
fp8_gemm_nt
,
get_tma_aligned_size
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
should_use_deepgemm_for_fp8_linear
,
...
...
@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
False
,
self
.
act_quant_group_shape
,
column_major_scales
=
True
,
tma_aligned_scales
=
True
,
use_ue8m0
=
self
.
use_deep_gemm_e8m0
,
)
if
self
.
is_deep_gemm_supported
...
...
@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
|
None
=
None
,
column_major_scales
:
bool
=
False
,
tma_aligned_scales
:
bool
=
False
,
out_q
:
torch
.
Tensor
|
None
=
None
,
use_ue8m0
:
bool
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
tma_aligned_scales: Outputs scales in TMA-aligned layout.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
...
...
@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
# Allocate the scale tensor in either row- or column-major format.
if
column_major_scales
:
shape
=
(
x
.
shape
[
-
1
]
//
group_size
,)
+
x
.
shape
[:
-
1
]
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
permute
(
-
1
,
-
2
)
if
tma_aligned_scales
:
m
=
x
.
shape
[
-
2
]
sf_k
=
x
.
shape
[
-
1
]
//
group_size
tma_aligned_m
=
get_tma_aligned_size
(
m
,
4
)
shape
=
x
.
shape
[:
-
2
]
+
(
m
,
sf_k
)
stride
=
(
(
1
,
tma_aligned_m
)
if
x
.
dim
()
==
2
else
(
tma_aligned_m
*
sf_k
,
1
,
tma_aligned_m
)
)
x_s
=
torch
.
empty_strided
(
shape
,
stride
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
else
:
shape
=
x
.
shape
[:
-
2
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
x
.
shape
[
-
2
])
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
permute
(
-
1
,
-
2
)
else
:
shape
=
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,)
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
...
...
vllm/utils/deep_gemm.py
View file @
d08b356e
...
...
@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
return
cdiv
(
x
,
y
)
*
y
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def
get_tma_aligned_size
(
x
:
int
,
element_size
:
int
):
return
_align
(
x
,
16
//
element_size
)
DEFAULT_BLOCK_SIZE
=
[
128
,
128
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment