Unverified commit a462331e, authored by bnellnm, committed by GitHub
Browse files

[Bugfix] Disable moe inplace for torch >= 2.9 (#26497)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 4069db3f
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
...@@ -235,7 +235,11 @@ def fused_marlin_moe( ...@@ -235,7 +235,11 @@ def fused_marlin_moe(
).view(-1, topk, K) ).view(-1, topk, K)
if output is None: if output is None:
output = hidden_states if inplace else torch.empty_like(hidden_states) if inplace and not disable_inplace():
output = hidden_states
else:
output = torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
......
...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
activation_without_mul, activation_without_mul,
disable_inplace,
moe_kernel_quantize_input, moe_kernel_quantize_input,
) )
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
...@@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: ...@@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if inplace: if inplace and not disable_inplace():
return torch_vllm_inplace_fused_experts return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts return torch_vllm_outplace_fused_experts
...@@ -1766,7 +1767,10 @@ def fused_experts_impl( ...@@ -1766,7 +1767,10 @@ def fused_experts_impl(
else: else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if ocp_mx_scheme is not None: if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True # TODO: On platforms for which `current_platform.supports_mx()` is True
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
count_expert_num_tokens, count_expert_num_tokens,
disable_inplace,
) )
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
...@@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1139,7 +1140,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if inplace and self.shared_experts is None: if inplace and self.shared_experts is None and not disable_inplace():
output = hidden_states output = hidden_states
else: else:
output = torch.zeros_like(hidden_states) output = torch.zeros_like(hidden_states)
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.utils.flashinfer import flashinfer_fp4_quantize from vllm.utils.flashinfer import flashinfer_fp4_quantize
...@@ -321,3 +321,10 @@ def _validate_scale_shape( ...@@ -321,3 +321,10 @@ def _validate_scale_shape(
def activation_without_mul(activation: str) -> str: def activation_without_mul(activation: str) -> str:
return activation + "_no_mul" return activation + "_no_mul"
def disable_inplace() -> bool:
    """Report whether in-place MoE output must be turned off.

    Torch custom ops can't deal with outputs aliasing inputs, so in-place
    execution needs to be disabled for torch >= 2.9.
    See https://github.com/vllm-project/vllm/issues/26378

    Returns:
        The result of ``is_torch_equal_or_newer("2.9")``; callers treat a
        true value as "allocate a separate output tensor".
    """
    needs_out_of_place = is_torch_equal_or_newer("2.9")
    return needs_out_of_place
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment