Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1f291412
Unverified
Commit
1f291412
authored
Sep 24, 2025
by
Wentao Ye
Committed by
GitHub
Sep 24, 2025
Browse files
[Refactor] Use DeepGEMM Col Major TMA Aligned Tensor (#25517)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
6160ba41
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
78 deletions
+34
-78
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
...hmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+6
-2
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+4
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+3
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+6
-4
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+1
-65
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+14
-1
No files found.
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
View file @
1f291412
...
...
@@ -8,12 +8,16 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
get_col_major_tma_aligned_tensor
,
per_token_group_quant_fp8
,
w8a8_triton_block_scaled_mm
,
)
from
vllm.triton_utils
import
triton
from
vllm.utils.deep_gemm
import
calc_diff
,
fp8_gemm_nt
,
per_block_cast_to_fp8
from
vllm.utils.deep_gemm
import
(
calc_diff
,
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
per_block_cast_to_fp8
,
)
def
benchmark_shape
(
m
:
int
,
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
1f291412
...
...
@@ -11,11 +11,12 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
cutlass_scaled_mm
,
get_col_major_tma_aligned_tensor
,
per_token_group_quant_fp8
,
w8a8_triton_block_scaled_mm
)
cutlass_scaled_mm
,
per_token_group_quant_fp8
,
w8a8_triton_block_scaled_mm
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
fp8_gemm_nt
,
per_block_cast_to_fp8
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
per_block_cast_to_fp8
)
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
1f291412
...
...
@@ -34,8 +34,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
reorder_w1w3_to_w3w1
,
select_nvfp4_gemm_impl
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
expert_weight_is_col_major
,
get_col_major_tma_aligned_tensor
,
requant_weight_ue8m0_inplace
)
expert_weight_is_col_major
,
requant_weight_ue8m0_inplace
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
)
...
...
@@ -50,7 +49,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
from
vllm.utils.deep_gemm
import
(
get_col_major_tma_aligned_tensor
,
is_deep_gemm_e8m0_used
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
1f291412
...
...
@@ -34,9 +34,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp
,
check_aiter_fp8_linear_support
,
create_fp8_input_scale
,
create_fp8_scale_parameter
,
create_fp8_weight_parameter
,
expert_weight_is_col_major
,
get_col_major_tma_aligned_tensor
,
maybe_post_process_fp8_weight_block
,
process_fp8_weight_
block
_strategy
,
process_fp8_weight_tensor_strategy
,
requant_weight_ue8m0_inplace
,
validate_fp8_block_shape
)
maybe_post_process_fp8_weight_block
,
process_fp8_weight_block_strategy
,
process_fp8_weight_
tensor
_strategy
,
requant_weight_ue8m0_inplace
,
validate_fp8_block_shape
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_moe_fp8_layer_for_marlin
)
...
...
@@ -53,7 +53,9 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
from
vllm.utils.deep_gemm
import
(
get_col_major_tma_aligned_tensor
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
)
from
vllm.utils.flashinfer
import
has_flashinfer_moe
if
TYPE_CHECKING
:
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
1f291412
...
...
@@ -23,7 +23,7 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
should_use_deepgemm_for_fp8_linear
)
...
...
@@ -749,70 +749,6 @@ def w8a8_triton_block_scaled_mm(
return
C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_tma_aligned_size
(
x
:
int
,
element_size
:
int
)
->
int
:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes
=
16
assert
tma_alignment_bytes
%
element_size
==
0
alignment
=
tma_alignment_bytes
//
element_size
return
cdiv
(
x
,
alignment
)
*
alignment
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in
# CUDA
assert
x
.
dim
()
in
(
2
,
3
)
remove_dim
=
False
m
,
n
=
x
.
shape
[
-
2
],
x
.
shape
[
-
1
]
aligned_m
=
get_tma_aligned_size
(
m
,
x
.
element_size
())
if
x
.
dim
()
==
2
:
if
x
.
stride
(
0
)
==
1
and
x
.
stride
(
1
)
==
aligned_m
:
return
x
x
,
remove_dim
=
x
.
unsqueeze
(
0
),
True
b
=
x
.
shape
[
0
]
# The last kernel gives a column-major TMA aligned layout
if
x
.
stride
(
0
)
==
aligned_m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
aligned_m
:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
def
requant_weight_ue8m0_inplace
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
...
...
vllm/utils/deep_gemm.py
View file @
1f291412
...
...
@@ -70,11 +70,13 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
_fp8_gemm_nt_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_masked_impl
:
Callable
[...,
Any
]
|
None
=
None
_get_mn_major_tma_aligned_tensor_impl
:
Callable
[...,
Any
]
|
None
=
None
def
_lazy_init
()
->
None
:
"""Import deep_gemm and resolve symbols on first use."""
global
_fp8_gemm_nt_impl
,
_grouped_impl
,
_grouped_masked_impl
global
_fp8_gemm_nt_impl
,
_grouped_impl
,
_grouped_masked_impl
,
\
_get_mn_major_tma_aligned_tensor_impl
# fast path
if
(
_fp8_gemm_nt_impl
is
not
None
or
_grouped_impl
is
not
None
...
...
@@ -95,6 +97,16 @@ def _lazy_init() -> None:
_fp8_gemm_nt_impl
=
getattr
(
_dg
,
"fp8_gemm_nt"
,
None
)
_grouped_impl
=
getattr
(
_dg
,
"m_grouped_fp8_gemm_nt_contiguous"
,
None
)
_grouped_masked_impl
=
getattr
(
_dg
,
"fp8_m_grouped_gemm_nt_masked"
,
None
)
_get_mn_major_tma_aligned_tensor_impl
=
getattr
(
_dg
,
"get_mn_major_tma_aligned_tensor"
,
None
)
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
_lazy_init
()
if
_get_mn_major_tma_aligned_tensor_impl
is
None
:
return
_missing
()
return
_get_mn_major_tma_aligned_tensor_impl
(
x
)
def
fp8_gemm_nt
(
*
args
,
**
kwargs
):
...
...
@@ -191,4 +203,5 @@ __all__ = [
"is_deep_gemm_e8m0_used"
,
"is_deep_gemm_supported"
,
"should_use_deepgemm_for_fp8_linear"
,
"get_col_major_tma_aligned_tensor"
,
]
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment