Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a462331e
Unverified
Commit
a462331e
authored
Oct 09, 2025
by
bnellnm
Committed by
GitHub
Oct 09, 2025
Browse files
[Bugfix] Disable moe inplace for torch >= 2.9 (#26497)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
4069db3f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
6 deletions
+22
-6
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+6
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+6
-2
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+2
-1
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+8
-1
No files found.
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
a462331e
...
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
,
disable_inplace
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_make_workspace_new
,
marlin_moe_intermediate_size
,
...
...
@@ -235,7 +235,11 @@ def fused_marlin_moe(
).
view
(
-
1
,
topk
,
K
)
if
output
is
None
:
output
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
if
inplace
and
not
disable_inplace
():
output
=
hidden_states
else
:
output
=
torch
.
empty_like
(
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
-
1
,
topk
,
K
),
dim
=
1
,
out
=
output
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
a462331e
...
...
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
activation_without_mul
,
disable_inplace
,
moe_kernel_quantize_input
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
dequant_mxfp4
...
...
@@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    """Pick the fused-experts entry point for the requested mode.

    Args:
        inplace: Whether the caller asked for the in-place variant.

    Returns:
        ``torch_vllm_inplace_fused_experts`` only when ``inplace`` is set and
        ``disable_inplace()`` does not veto it; otherwise
        ``torch_vllm_outplace_fused_experts``.
    """
    # `disable_inplace()` is only consulted when in-place was requested.
    use_inplace = inplace and not disable_inplace()
    if not use_inplace:
        return torch_vllm_outplace_fused_experts
    return torch_vllm_inplace_fused_experts
...
...
@@ -1766,7 +1767,10 @@ def fused_experts_impl(
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
out_hidden_states
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
if
inplace
and
not
disable_inplace
():
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
ocp_mx_scheme
is
not
None
:
# TODO: On platforms for which `current_platform.supports_mx()` is True
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
a462331e
...
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
count_expert_num_tokens
,
disable_inplace
,
)
from
vllm.utils
import
cdiv
from
vllm.v1.worker.ubatching
import
(
...
...
@@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if
inplace
and
self
.
shared_experts
is
None
:
if
inplace
and
self
.
shared_experts
is
None
and
not
disable_inplace
()
:
output
=
hidden_states
else
:
output
=
torch
.
zeros_like
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
a462331e
...
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
,
is_torch_equal_or_newer
from
vllm.utils.flashinfer
import
flashinfer_fp4_quantize
...
...
@@ -321,3 +321,10 @@ def _validate_scale_shape(
def activation_without_mul(activation: str) -> str:
    """Return ``activation`` with the ``"_no_mul"`` suffix appended.

    Args:
        activation: Name of the activation.

    Returns:
        The activation name followed by ``"_no_mul"``.
    """
    suffix = "_no_mul"
    return activation + suffix
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378
def disable_inplace() -> bool:
    """Report whether in-place MoE output must be disabled.

    Returns:
        The result of ``is_torch_equal_or_newer("2.9")``: True when the
        installed torch is at least 2.9, in which case callers should write
        to a separate output buffer rather than aliasing their input.
    """
    torch_is_2_9_or_newer = is_torch_equal_or_newer("2.9")
    return torch_is_2_9_or_newer
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment