Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e2de455c
Unverified
Commit
e2de455c
authored
Jul 10, 2025
by
Wentao Ye
Committed by
GitHub
Jul 10, 2025
Browse files
[Feature] Integrate SM100 DeepGEMM support (#20087)
parent
5b032352
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
397 additions
and
114 deletions
+397
-114
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+3
-0
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+8
-8
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+5
-0
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+8
-47
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+11
-16
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+9
-12
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+11
-11
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+9
-3
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+0
-1
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+5
-2
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+6
-1
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+1
-2
vllm/model_executor/layers/quantization/deepgemm.py
vllm/model_executor/layers/quantization/deepgemm.py
+3
-5
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+42
-4
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+124
-2
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+152
-0
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
e2de455c
...
...
@@ -86,6 +86,9 @@ def benchmark_config(
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_deep_gemm
:
# we use the default block shape for deepgemm
block_quant_shape
=
[
128
,
128
]
if
use_fp8_w8a8
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
...
...
tests/kernels/moe/test_block_fp8.py
View file @
e2de455c
...
...
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
dg_available
=
False
try
:
import
deep_gemm
dg_available
=
True
except
ImportError
:
pass
dg_available
=
has_deep_gemm
()
if
dg_available
:
from
deep_gemm
import
get_m_alignment_for_contiguous_layout
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
...
...
@@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
pytest
.
mark
.
skipif
(
is_blackwell_deep_gemm_used
(),
reason
=
"Not E8M0 scale MOE"
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
,
monkeypatch
):
...
...
@@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
torch
.
manual_seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
str
(
chunk_size
))
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
block_m
=
get_m_alignment_for_contiguous_layout
()
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
e2de455c
...
...
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
...
...
@@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
requires_deep_ep
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_blackwell_deep_gemm_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ht_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
"""
...
...
@@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
requires_deep_ep
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_blackwell_deep_gemm_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
...
...
tests/kernels/moe/test_deepgemm.py
View file @
e2de455c
...
...
@@ -13,48 +13,18 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.
model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant
_fp8
)
from
vllm.utils
import
cdiv
from
vllm.
utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
calc_diff
,
per_block_cast_to
_fp8
,
per_token_group_cast_to_fp8
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
if
has_deep_gemm
:
import
deep_gemm
BLOCK_M
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
BLOCK_SIZE
=
[
BLOCK_M
,
BLOCK_M
]
BLOCK_SIZE
=
[
128
,
128
]
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
has_deep_gemm
,
not
has_deep_gemm
()
,
reason
=
"Requires deep_gemm kernels"
,
)
def
calc_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
cdiv
(
m
,
128
)
*
128
,
cdiv
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
def
make_block_quant_fp8_weights
(
e
:
int
,
n
:
int
,
...
...
@@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16
=
torch
.
randn
(
m
,
k
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
).
clamp_min_
(
-
1
).
clamp_max_
(
1
)
_
,
a1_scale
=
per_token_group_
quant
_fp8
(
tokens_bf16
,
block_size
[
1
])
_
,
a1_scale
=
per_token_group_
cast_to
_fp8
(
tokens_bf16
,
block_size
[
1
])
# expert weight tensors
w1
,
w2
,
w1_s
,
w2_s
=
make_block_quant_fp8_weights
(
num_experts
,
n
,
k
,
...
...
@@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape
=
block_size
,
allow_deep_gemm
=
True
,
)
base
=
out_triton
.
abs
().
mean
()
atol
=
0.1
*
base
.
clamp
(
min
=
1e-2
)
# 10% of mean, but not lower than 1e-3
rtol
=
0.05
# ----- Compare -----
torch
.
testing
.
assert_close
(
out_deepgemm
.
to
(
torch
.
float32
),
out_triton
.
to
(
torch
.
float32
),
rtol
=
rtol
,
atol
=
float
(
atol
),
)
diff
=
calc_diff
(
out_deepgemm
,
out_triton
)
assert
diff
<
0.001
,
f
"Diff exceeded 1%:
{
diff
}
"
# Note: W1 has shape (E, 2N, K), so N = 512
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
e2de455c
...
...
@@ -8,19 +8,15 @@ import pytest
import
torch
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_fp8
,
native_w8a8_block_matmul
,
per_block_cast_to_fp8
)
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
get_col_major_tma_aligned_tensor
,
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
dg_available
=
False
try
:
import
deep_gemm
dg_available
=
True
except
ImportError
:
pass
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
per_block_cast_to_fp8
,
per_token_group_cast_to_fp8
)
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
...
...
@@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
pytest
.
mark
.
skipif
(
not
has_deep_gemm
(),
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
# only aligned sizes
...
...
@@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
_
,
block_k
=
block_size
[
0
],
block_size
[
1
]
A_fp8
,
As_fp8
=
per_token_group_quant_fp8
(
A_fp32
,
block_k
)
A_fp8
,
As_fp8
=
per_token_group_cast_to_fp8
(
A_fp32
,
block_size
[
1
])
B_fp8
,
Bs_fp8
=
per_block_cast_to_fp8
(
B_fp32
)
As
=
As_fp8
.
to
(
torch
.
float32
)
...
...
@@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
out_dtype
)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8
=
deep_gemm
.
get_col_major_tma_aligned_tensor
(
As_fp8
)
As_fp8
=
get_col_major_tma_aligned_tensor
(
As_fp8
)
out
=
torch
.
zeros
((
M
,
N
),
device
=
'cuda'
,
dtype
=
out_dtype
)
assert
As_fp8
.
shape
==
(
M
,
(
K
+
127
)
//
128
),
f
"
{
As_fp8
.
shape
}
!=
{
(
M
,
(
K
+
127
)
//
128
)
}
"
deep
_gemm
.
gemm_fp8_fp8_bf16
_nt
((
A_fp8
,
As_fp8
),
(
B_fp8
,
Bs_fp8
),
out
)
fp8
_gemm_nt
((
A_fp8
,
As_fp8
),
(
B_fp8
,
Bs_fp8
),
out
)
rel_diff
=
(
torch
.
mean
(
torch
.
abs
(
out
.
to
(
torch
.
float32
)
-
ref_out
.
to
(
torch
.
float32
)))
/
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
e2de455c
...
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
fp8_m_grouped_gemm_nt_masked
logger
=
init_logger
(
__name__
)
...
...
@@ -271,7 +272,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
expert_tokens_meta
is
not
None
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
import
deep_gemm
as
dg
assert
hidden_states
.
ndim
==
3
assert
self
.
block_shape
is
not
None
...
...
@@ -289,18 +289,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m
=
max_num_tokens
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
out
=
workspace1
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
fp8_m_grouped_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
out
=
workspace1
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
expert_num_tokens
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
out
=
output
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
fp8_m_grouped_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
out
=
output
,
masked_m
=
expert_num_tokens
,
expected_m
=
expected_m
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
e2de455c
...
...
@@ -14,9 +14,10 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.utils
import
has_deep_gemm
,
round_up
from
vllm.utils.deep_gemm
import
(
m_grouped_fp8_gemm_nt_contiguous
,
per_token_group_cast_to_fp8
)
logger
=
init_logger
(
__name__
)
...
...
@@ -127,7 +128,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
):
import
deep_gemm
as
dg
assert
self
.
block_shape
is
not
None
a1q
=
hidden_states
...
...
@@ -164,19 +164,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(
M_sum
,
N
//
2
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
dg
.
m_grouped_gemm_
fp8_fp8_bf16_
nt_contiguous
(
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
m_grouped_
fp8_
gemm_nt_contiguous
(
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
self
.
activation
(
activation
,
act_out
,
mm1_out
.
view
(
-
1
,
N
))
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
a2q
,
a2q_scale
=
per_token_group_
quant
_fp8
(
act_out
,
self
.
block_shape
[
1
],
column_major_scales
=
True
,
out_q
=
quant_out
)
a2q
,
a2q_scale
=
per_token_group_
cast_to
_fp8
(
act_out
,
self
.
block_shape
[
1
],
column_major_scales
=
True
,
out_q
=
quant_out
)
dg
.
m_grouped_gemm_
fp8_fp8_bf16_
nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
m_grouped_
fp8_
gemm_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
torch
.
index_select
(
mm2_out
,
0
,
inv_perm
,
out
=
output
.
view
((
-
1
,
K
)))
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
e2de455c
...
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
from
.rocm_aiter_fused_moe
import
is_rocm_aiter_moe_enabled
...
...
@@ -1171,9 +1172,15 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases becuase they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
N
=
w1
.
size
(
1
)
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
N
>
512
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
)):
should_use_deep_gemm
=
((
N
>
512
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
))
or
is_blackwell_deep_gemm_used
())
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
should_use_deep_gemm
):
assert
apply_router_weight_on_input
is
False
return
deep_gemm_moe_fp8
(
hidden_states
=
hidden_states
,
...
...
@@ -1363,7 +1370,6 @@ def fused_experts_impl(
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
qcurr_hidden_states
,
a1q_scale
=
moe_kernel_quantize_input
(
A
=
curr_hidden_states
,
A_scale
=
a1_scale
,
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
e2de455c
...
...
@@ -48,7 +48,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert
topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
a1
.
mul_
(
topk_weights
.
to
(
a1
.
dtype
))
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
a1
,
a1_scale
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
)
...
...
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
View file @
e2de455c
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
_valid_deep_gemm
,
_valid_deep_gemm_shape
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
class
TritonOrDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
...
@@ -102,7 +103,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if
self
.
allow_deep_gemm
and
_valid_deep_gemm_shape
(
M
,
N
,
K
):
if
self
.
allow_deep_gemm
and
(
_valid_deep_gemm_shape
(
M
,
N
,
K
)
or
is_blackwell_deep_gemm_used
()):
assert
self
.
deep_gemm_expert
is
not
None
return
self
.
deep_gemm_expert
.
workspace_shapes
(
a
,
aq
,
M
,
N
,
K
,
topk
,
global_num_experts
,
local_num_experts
)
...
...
@@ -132,7 +134,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
):
use_deep_gemm
=
(
self
.
allow_deep_gemm
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
))
and
(
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
)
or
is_blackwell_deep_gemm_used
()))
experts
=
self
.
deep_gemm_expert
if
use_deep_gemm
else
self
.
triton_expert
assert
experts
is
not
None
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
e2de455c
...
...
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
from
vllm.utils.deep_gemm
import
(
is_blackwell_deep_gemm_used
,
per_token_group_cast_to_fp8
)
@
triton
.
jit
...
...
@@ -115,7 +117,10 @@ def _fp8_quantize(
assert
not
per_act_token
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
if
is_blackwell_deep_gemm_used
():
A
,
A_scale
=
per_token_group_cast_to_fp8
(
A
,
block_k
)
else
:
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
cdiv
(
A
.
size
(
-
1
),
block_k
)
==
A_scale
.
size
(
-
1
)
return
A
,
A_scale
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
e2de455c
...
...
@@ -8,10 +8,9 @@ from typing import Optional, Union
import
numpy
as
np
import
torch
import
triton
import
triton.language
as
tl
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.triton_utils
import
tl
,
triton
@
triton
.
jit
()
...
...
vllm/model_executor/layers/quantization/deepgemm.py
View file @
e2de455c
...
...
@@ -6,10 +6,8 @@ import torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.utils
import
direct_register_custom_op
,
has_deep_gemm
if
has_deep_gemm
():
import
deep_gemm
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils.deep_gemm
import
fp8_gemm_nt
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm(
output_dtype
)
# Deepgemm only supports output tensor type as bfloat16
assert
C
.
dtype
==
torch
.
bfloat16
deep
_gemm
.
gemm_fp8_fp8_bf16
_nt
((
A
,
As
),
(
B
,
Bs
),
C
)
fp8
_gemm_nt
((
A
,
As
),
(
B
,
Bs
),
C
)
return
C
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
e2de455c
...
...
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
get_col_major_tma_aligned_tensor
,
requant_weight_ue8m0_inplace
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_moe_fp8_layer_for_marlin
)
...
...
@@ -40,6 +42,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
...
...
@@ -393,6 +396,19 @@ class Fp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
del
layer
.
input_scale
# On B200, DeepGemm only support E8M0 scale, which means we need to
# requantize the weight and input to the specific scale
# at the same time.
if
is_blackwell_deep_gemm_used
():
assert
layer
.
weight_block_size
is
not
None
block_sz
=
tuple
(
layer
.
weight_block_size
)
requant_weight_ue8m0_inplace
(
layer
.
weight
.
data
,
layer
.
weight_scale_inv
.
data
if
hasattr
(
layer
,
"weight_scale_inv"
)
else
layer
.
weight_scale
.
data
,
block_sz
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
@@ -670,15 +686,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if
self
.
allow_deep_gemm
:
if
self
.
allow_deep_gemm
and
not
is_blackwell_deep_gemm_used
()
:
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
if
_is_col_major
(
layer
.
w13_weight_scale_inv
):
layer
.
w13_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale_inv
).
contiguous
()
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale_inv
).
contiguous
()
if
_is_col_major
(
layer
.
w2_weight_scale_inv
):
layer
.
w2_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
# If checkpoint is fp16, quantize in place.
elif
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
...
...
@@ -797,6 +812,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del
layer
.
w13_input_scale
del
layer
.
w2_input_scale
if
is_blackwell_deep_gemm_used
():
assert
layer
.
weight_block_size
is
not
None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz
=
tuple
(
layer
.
weight_block_size
)
requant_weight_ue8m0_inplace
(
layer
.
w13_weight
.
data
,
layer
.
w13_weight_scale_inv
.
data
,
block_sz
,
)
requant_weight_ue8m0_inplace
(
layer
.
w2_weight
.
data
,
layer
.
w2_weight_scale_inv
.
data
,
block_sz
,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if
_is_col_major
(
layer
.
w13_weight_scale_inv
):
layer
.
w13_weight_scale_inv
=
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale_inv
).
contiguous
()
if
_is_col_major
(
layer
.
w2_weight_scale_inv
):
layer
.
w2_weight_scale_inv
=
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
e2de455c
...
...
@@ -5,6 +5,7 @@
import
functools
import
json
import
os
from
collections.abc
import
Sequence
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
...
...
@@ -13,7 +14,7 @@ import vllm.envs as envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_dequantize
)
group_broadcast
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
)
from
vllm.platforms
import
current_platform
...
...
@@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now.
"""
x_dq_block
=
scaled_dequantize
(
x_q_block
,
x_s
)
x_dq_block
=
group_broadcast
(
x_q_block
,
x_s
)
x_q_tensor
,
scale
=
input_to_float8
(
x_dq_block
,
dtype
=
x_q_block
.
dtype
)
return
x_q_tensor
,
scale
...
...
@@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
)
return
C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_tma_aligned_size
(
x
:
int
,
element_size
:
int
)
->
int
:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes
=
16
assert
tma_alignment_bytes
%
element_size
==
0
alignment
=
tma_alignment_bytes
//
element_size
return
cdiv
(
x
,
alignment
)
*
alignment
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in
# CUDA
assert
x
.
dim
()
in
(
2
,
3
)
remove_dim
=
False
m
,
n
=
x
.
shape
[
-
2
],
x
.
shape
[
-
1
]
aligned_m
=
get_tma_aligned_size
(
m
,
x
.
element_size
())
if
x
.
dim
()
==
2
:
if
x
.
stride
(
0
)
==
1
and
x
.
stride
(
1
)
==
aligned_m
:
return
x
x
,
remove_dim
=
x
.
unsqueeze
(
0
),
True
b
=
x
.
shape
[
0
]
# The last kernel gives a column-major TMA aligned layout
if
x
.
stride
(
0
)
==
aligned_m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
aligned_m
:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
def
requant_weight_ue8m0_inplace
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
block_size
:
Sequence
[
int
]
=
(
128
,
128
),
)
->
None
:
"""Re-quantise *weight* so that its per-block scaling factors are in the
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
Args:
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
Expected shape ``(..., M, K)``.
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
with shape ``(..., M // block_size[0], K // block_size[1])``.
block_size: 2-element iterable ``[block_m, block_k]`` describing the
block quantisation granularity.
"""
if
weight
.
numel
()
==
0
:
return
if
weight
.
dtype
!=
torch
.
float8_e4m3fn
:
raise
ValueError
(
"Expected *weight* to be torch.float8_e4m3fn, got "
f
"
{
weight
.
dtype
}
instead."
)
from
vllm.utils.deep_gemm
import
per_block_cast_to_fp8
block_m
,
block_k
=
int
(
block_size
[
0
]),
int
(
block_size
[
1
])
# Flatten leading dimensions so we can iterate over the last two dims.
leading_shape
=
weight
.
shape
[:
-
2
]
if
len
(
leading_shape
)
==
0
:
w_view
=
weight
.
unsqueeze
(
0
)
s_view
=
weight_scale
.
unsqueeze
(
0
)
else
:
w_view
=
weight
.
reshape
(
-
1
,
weight
.
shape
[
-
2
],
weight
.
shape
[
-
1
])
s_view
=
weight_scale
.
reshape
(
-
1
,
*
weight_scale
.
shape
[
-
2
:])
num_mats
=
w_view
.
size
(
0
)
for
idx
in
range
(
num_mats
):
w_q
=
w_view
[
idx
]
s_old
=
s_view
[
idx
]
# De-quantise with the *old* scaling factors (float32).
m_cur
,
k_cur
=
w_q
.
shape
s_float
=
s_old
.
to
(
torch
.
float32
)
# Expand scales along rows and cols by block size, then crop.
s_exp_r
=
torch
.
repeat_interleave
(
s_float
,
block_m
,
dim
=
0
)
s_exp
=
torch
.
repeat_interleave
(
s_exp_r
,
block_k
,
dim
=
1
)
s_exp
=
s_exp
[:
m_cur
,
:
k_cur
]
w_dq
=
w_q
.
to
(
torch
.
float32
)
*
s_exp
# Re-quantise using power-of-two scaling (UE8M0).
w_requant
,
s_requant
=
per_block_cast_to_fp8
(
w_dq
,
[
block_m
,
block_k
])
# Write back the results in-place.
w_q
.
copy_
(
w_requant
)
s_old
.
copy_
(
s_requant
)
vllm/utils/deep_gemm.py
0 → 100644
View file @
e2de455c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import **only** these wrappers.
"""
from
__future__
import
annotations
import
functools
import
importlib
from
typing
import
Any
,
Callable
,
NoReturn
import
torch
import
vllm.envs
as
envs
from
vllm.utils
import
cuda_get_device_properties
,
has_deep_gemm
@
functools
.
cache
def
is_blackwell_deep_gemm_used
()
->
bool
:
"""Return ``True`` if vLLM is configured to use DeepGEMM on a
Blackwell-class GPU.
"""
if
not
(
envs
.
VLLM_USE_DEEP_GEMM
and
has_deep_gemm
()
and
_per_block_cast_impl
is
not
None
):
return
False
return
cuda_get_device_properties
(
0
,
(
"major"
,
))[
0
]
==
10
def
_missing
(
*
_
:
Any
,
**
__
:
Any
)
->
NoReturn
:
"""Placeholder for unavailable DeepGEMM backend."""
raise
RuntimeError
(
"DeepGEMM backend is not available. Please install the `deep_gemm` "
"package to enable FP8 kernels."
)
def
_resolve_symbol
(
module
,
new
:
str
,
old
:
str
)
->
Callable
[...,
Any
]
|
None
:
"""Return the *new* symbol if it exists, otherwise the *old* one."""
if
hasattr
(
module
,
new
):
return
getattr
(
module
,
new
)
if
hasattr
(
module
,
old
):
return
getattr
(
module
,
old
)
return
None
if
not
has_deep_gemm
():
_fp8_gemm_nt_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_masked_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_token_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_block_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
else
:
_dg
=
importlib
.
import_module
(
"deep_gemm"
)
# type: ignore
_fp8_gemm_nt_impl
=
_resolve_symbol
(
_dg
,
"fp8_gemm_nt"
,
"gemm_fp8_fp8_bf16_nt"
,
)
_grouped_impl
=
_resolve_symbol
(
_dg
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous"
,
)
_grouped_masked_impl
=
_resolve_symbol
(
_dg
,
"fp8_m_grouped_gemm_nt_masked"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_masked"
,
)
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try
:
_math_mod
=
importlib
.
import_module
(
"deep_gemm.utils.math"
)
# type: ignore
_per_token_cast_impl
=
getattr
(
_math_mod
,
"per_token_cast_to_fp8"
,
None
)
_per_block_cast_impl
=
getattr
(
_math_mod
,
"per_block_cast_to_fp8"
,
None
)
except
ModuleNotFoundError
:
_per_token_cast_impl
=
None
_per_block_cast_impl
=
None
def
fp8_gemm_nt
(
*
args
,
**
kwargs
):
if
_fp8_gemm_nt_impl
is
None
:
return
_missing
(
*
args
,
**
kwargs
)
return
_fp8_gemm_nt_impl
(
*
args
,
**
kwargs
)
def
m_grouped_fp8_gemm_nt_contiguous
(
*
args
,
**
kwargs
):
if
_grouped_impl
is
None
:
return
_missing
(
*
args
,
**
kwargs
)
return
_grouped_impl
(
*
args
,
**
kwargs
)
def
fp8_m_grouped_gemm_nt_masked
(
*
args
,
**
kwargs
):
if
_grouped_masked_impl
is
None
:
return
_missing
(
*
args
,
**
kwargs
)
return
_grouped_masked_impl
(
*
args
,
**
kwargs
)
def
per_token_group_cast_to_fp8
(
x
,
group_size
,
*
args
,
**
kwargs
):
"""Wrapper for token-wise FP8 quantisation.
• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
"""
if
_per_token_cast_impl
is
not
None
and
is_blackwell_deep_gemm_used
():
assert
group_size
==
128
,
"group_size must be 128 for deepgemm"
return
_per_token_cast_impl
(
x
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
as
_ptg
)
return
_ptg
(
x
,
group_size
,
*
args
,
**
kwargs
)
def
per_block_cast_to_fp8
(
x
,
*
args
,
**
kwargs
):
if
_per_block_cast_impl
is
not
None
and
is_blackwell_deep_gemm_used
():
return
_per_block_cast_impl
(
x
)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
from
tests.kernels.quant_utils
import
per_block_cast_to_fp8
as
_pbcf
return
_pbcf
(
x
,
*
args
,
**
kwargs
)
def
calc_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
):
"""Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
error, causing ``torch.testing.assert_close`` to fail. Instead of checking
every element, we compute a cosine-style similarity over the whole tensor
and report ``1 - sim``. Once kernel accuracy improves this helper can be
removed.
"""
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
__all__
=
[
"calc_diff"
,
"fp8_gemm_nt"
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"fp8_m_grouped_gemm_nt_masked"
,
"per_token_group_cast_to_fp8"
,
"per_block_cast_to_fp8"
,
"is_blackwell_deep_gemm_used"
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment