Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
97abeb1d
Unverified
Commit 97abeb1d authored Jul 08, 2025 by Duncan Moss
Committed by GitHub Jul 09, 2025
Browse files
[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
parent
34dad19e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 5 additions and 7 deletions
+5
-7
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+4
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-1
No files found.
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
97abeb1d
...
...
@@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     return out.to(dtype=out_dtype)
 
 
-def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
-                                             w1: torch.Tensor,
+def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
                                              w2: torch.Tensor) -> bool:
 
-    def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
-        return M >= 128 and N % 128 == 0 and K % 128 == 0
+    def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
+        return N % 128 == 0 and K % 128 == 0
 
-    m = hidden_states.size(0)
     _, K, N = w2.size()
-    if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
+    if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
         logger.debug(
             "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
         return False
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
97abeb1d
...
...
@@ -1180,7 +1180,7 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
+          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
         assert apply_router_weight_on_input is False
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment