Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42d440c2
Unverified
Commit
42d440c2
authored
Jul 12, 2025
by
Wentao Ye
Committed by
GitHub
Jul 12, 2025
Browse files
[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
f45a3328
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
26 additions
and
42 deletions
+26
-42
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+4
-3
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+2
-3
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+7
-6
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+1
-6
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+12
-3
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+0
-21
No files found.
tests/kernels/moe/test_deepgemm.py
View file @
42d440c2
...
...
@@ -13,9 +13,10 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
calc_diff
,
per_block_cast_to_fp8
,
per_token_group_cast_to_fp8
)
from
vllm.utils.deep_gemm
import
calc_diff
,
per_block_cast_to_fp8
BLOCK_SIZE
=
[
128
,
128
]
...
...
@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16
=
torch
.
randn
(
m
,
k
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
).
clamp_min_
(
-
1
).
clamp_max_
(
1
)
_
,
a1_scale
=
per_token_group_
cast_to
_fp8
(
tokens_bf16
,
block_size
[
1
])
_
,
a1_scale
=
per_token_group_
quant
_fp8
(
tokens_bf16
,
block_size
[
1
])
# expert weight tensors
w1
,
w2
,
w1_s
,
w2_s
=
make_block_quant_fp8_weights
(
num_experts
,
n
,
k
,
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
42d440c2
...
...
@@ -15,8 +15,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
per_block_cast_to_fp8
,
per_token_group_cast_to_fp8
)
from
vllm.utils.deep_gemm
import
fp8_gemm_nt
,
per_block_cast_to_fp8
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
...
...
@@ -117,7 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
A_fp8
,
As_fp8
=
per_token_group_
cast_to
_fp8
(
A_fp32
,
block_size
[
1
])
A_fp8
,
As_fp8
=
per_token_group_
quant
_fp8
(
A_fp32
,
block_size
[
1
])
B_fp8
,
Bs_fp8
=
per_block_cast_to_fp8
(
B_fp32
)
As
=
As_fp8
.
to
(
torch
.
float32
)
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
42d440c2
...
...
@@ -15,9 +15,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
has_deep_gemm
,
round_up
from
vllm.utils.deep_gemm
import
(
m_grouped_fp8_gemm_nt_contiguous
,
per_token_group_cast_to_fp8
)
from
vllm.utils.deep_gemm
import
m_grouped_fp8_gemm_nt_contiguous
logger
=
init_logger
(
__name__
)
...
...
@@ -170,10 +171,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
activation
(
activation
,
act_out
,
mm1_out
.
view
(
-
1
,
N
))
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
a2q
,
a2q_scale
=
per_token_group_
cast_to
_fp8
(
act_out
,
self
.
block_shape
[
1
],
column_major_scales
=
True
,
out_q
=
quant_out
)
a2q
,
a2q_scale
=
per_token_group_
quant
_fp8
(
act_out
,
self
.
block_shape
[
1
],
column_major_scales
=
True
,
out_q
=
quant_out
)
m_grouped_fp8_gemm_nt_contiguous
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
42d440c2
...
...
@@ -15,8 +15,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
from
vllm.utils.deep_gemm
import
(
is_blackwell_deep_gemm_used
,
per_token_group_cast_to_fp8
)
@
triton
.
jit
...
...
@@ -119,10 +117,7 @@ def _fp8_quantize(
assert
not
per_act_token
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
if
is_blackwell_deep_gemm_used
():
A
,
A_scale
=
per_token_group_cast_to_fp8
(
A
,
block_k
)
else
:
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
cdiv
(
A
.
size
(
-
1
),
block_k
)
==
A_scale
.
size
(
-
1
)
return
A
,
A_scale
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
42d440c2
...
...
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
,
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
logger
=
init_logger
(
__name__
)
...
...
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
# Information for float8
fp8_min
,
fp8_max
,
use_ue8m0
:
tl
.
constexpr
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
...
...
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
scale_raw
=
_absmax
/
fp8_max
y_s
=
tl
.
math
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
scale_raw
)))
if
use_ue8m0
else
scale_raw
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
...
...
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
# Information for float8
fp8_min
,
fp8_max
,
use_ue8m0
:
tl
.
constexpr
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
...
...
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
scale_raw
=
_absmax
/
fp8_max
y_s
=
tl
.
math
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
scale_raw
)))
if
use_ue8m0
else
scale_raw
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
...
...
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
"""
dtype
=
current_platform
.
fp8_dtype
()
if
dtype
is
None
else
dtype
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
...
...
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
use_ue8m0
=
is_blackwell_deep_gemm_used
(),
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
...
...
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
use_ue8m0
=
is_blackwell_deep_gemm_used
(),
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
...
...
vllm/utils/deep_gemm.py
View file @
42d440c2
...
...
@@ -49,7 +49,6 @@ if not has_deep_gemm():
_fp8_gemm_nt_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_masked_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_token_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_block_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
else
:
_dg
=
importlib
.
import_module
(
"deep_gemm"
)
# type: ignore
...
...
@@ -74,12 +73,9 @@ else:
try
:
_math_mod
=
importlib
.
import_module
(
"deep_gemm.utils.math"
)
# type: ignore
_per_token_cast_impl
=
getattr
(
_math_mod
,
"per_token_cast_to_fp8"
,
None
)
_per_block_cast_impl
=
getattr
(
_math_mod
,
"per_block_cast_to_fp8"
,
None
)
except
ModuleNotFoundError
:
_per_token_cast_impl
=
None
_per_block_cast_impl
=
None
...
...
@@ -101,22 +97,6 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
return
_grouped_masked_impl
(
*
args
,
**
kwargs
)
def
per_token_group_cast_to_fp8
(
x
,
group_size
,
*
args
,
**
kwargs
):
"""Wrapper for token-wise FP8 quantisation.
• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
"""
if
_per_token_cast_impl
is
not
None
and
is_blackwell_deep_gemm_used
():
assert
group_size
==
128
,
"group_size must be 128 for deepgemm"
return
_per_token_cast_impl
(
x
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
as
_ptg
)
return
_ptg
(
x
,
group_size
,
*
args
,
**
kwargs
)
def
per_block_cast_to_fp8
(
x
,
*
args
,
**
kwargs
):
if
_per_block_cast_impl
is
not
None
and
is_blackwell_deep_gemm_used
():
return
_per_block_cast_impl
(
x
)
...
...
@@ -146,7 +126,6 @@ __all__ = [
"fp8_gemm_nt"
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"fp8_m_grouped_gemm_nt_masked"
,
"per_token_group_cast_to_fp8"
,
"per_block_cast_to_fp8"
,
"is_blackwell_deep_gemm_used"
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment