Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5923ab95
Unverified
Commit
5923ab95
Authored Jul 10, 2025 by Duncan Moss
Committed by GitHub on Jul 11, 2025
Browse files
[fix]: disable cutlass block scaled group gemm for EP (#20781)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
parent
0cf893ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 34 additions and 9 deletions
+34
-9
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
...ation/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
+4
-5
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+27
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-2
No files found.
csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
View file @
5923ab95
...
...
@@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
reinterpret_cast
<
typename
ScheduleConfig
::
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
a_ptrs
.
get_device
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
int
device_id
=
a_ptrs
.
device
().
index
();
static
const
cutlass
::
KernelHardwareInfo
hw_info
{
device_id
,
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
device_id
)};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
5923ab95
...
...
@@ -553,8 +553,10 @@ def cutlass_moe_fp4(a: torch.Tensor,
return
out
.
to
(
dtype
=
out_dtype
)
def
_valid_cutlass_block_scaled_grouped_gemm
(
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
)
->
bool
:
def
_valid_cutlass_block_scaled_grouped_gemm
(
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
inplace
:
bool
,
activation
:
str
,
apply_router_weight_on_input
:
bool
,
expert_map
:
Optional
[
torch
.
Tensor
])
->
bool
:
def
_valid_cutlass_block_scaled_grouped_gemm_shape
(
N
:
int
,
K
:
int
):
return
N
%
128
==
0
and
K
%
128
==
0
...
...
@@ -570,6 +572,29 @@ def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s)."
)
return
False
if
expert_map
is
not
None
:
logger
.
debug
(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
" not supported."
)
return
False
if
activation
!=
"silu"
:
logger
.
debug
(
"CutlassBlockScaledGroupedGemm disabled: only activation silu is"
" supported."
)
return
False
if
apply_router_weight_on_input
:
logger
.
debug
(
"CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported."
)
return
False
if
inplace
:
logger
.
debug
(
"CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
)
return
False
return
True
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
5923ab95
...
...
@@ -1192,8 +1192,9 @@ def fused_experts(
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
(
allow_cutlass_block_scaled_grouped_gemm
and
use_fp8_w8a8
and
_valid_cutlass_block_scaled_grouped_gemm
(
w1
,
w2
)):
assert
apply_router_weight_on_input
is
False
and
_valid_cutlass_block_scaled_grouped_gemm
(
w1
,
w2
,
inplace
,
activation
,
apply_router_weight_on_input
,
expert_map
)):
return
run_cutlass_block_scaled_fused_experts
(
a
=
hidden_states
,
w1
=
w1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment