Unverified commit 7361ab37, authored by Michael Goin, committed by GitHub
Browse files

Remove redundant mutates_args and dispatch_key for direct_register_custom_op (#25512)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 95bc60e4
...@@ -575,9 +575,7 @@ def unified_attention_fake( ...@@ -575,9 +575,7 @@ def unified_attention_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="unified_attention", op_name="unified_attention",
op_func=unified_attention, op_func=unified_attention,
mutates_args=[],
fake_impl=unified_attention_fake, fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe, tags=tag_cudagraph_unsafe,
) )
...@@ -628,6 +626,5 @@ direct_register_custom_op( ...@@ -628,6 +626,5 @@ direct_register_custom_op(
op_func=unified_attention_with_output, op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"], mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake, fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe, tags=tag_cudagraph_unsafe,
) )
...@@ -547,7 +547,6 @@ if flashinfer_comm is not None: ...@@ -547,7 +547,6 @@ if flashinfer_comm is not None:
"scale_out", "scale_out",
], ],
fake_impl=call_trtllm_fused_allreduce_norm_fake, fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
flashinfer_trtllm_fused_allreduce_norm = ( flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
......
...@@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm): ...@@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
direct_register_custom_op( direct_register_custom_op(
op_name="all_reduce_symmetric_with_copy", op_name="all_reduce_symmetric_with_copy",
op_func=all_reduce_symmetric_with_copy_impl, op_func=all_reduce_symmetric_with_copy_impl,
mutates_args=[],
fake_impl=all_reduce_symmetric_with_copy_fake, fake_impl=all_reduce_symmetric_with_copy_fake,
) )
......
...@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, ...@@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
if supports_custom_op(): if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op( direct_register_custom_op(
op_name="all_reduce", op_name="all_reduce",
op_func=all_reduce, op_func=all_reduce,
mutates_args=[],
fake_impl=all_reduce_fake, fake_impl=all_reduce_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="reduce_scatter", op_name="reduce_scatter",
op_func=reduce_scatter, op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake, fake_impl=reduce_scatter_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="all_gather", op_name="all_gather",
op_func=all_gather, op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake, fake_impl=all_gather_fake,
dispatch_key=current_platform.dispatch_key,
) )
......
...@@ -11,7 +11,6 @@ import torch ...@@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -283,7 +282,6 @@ try: ...@@ -283,7 +282,6 @@ try:
op_func=_lora_expand, op_func=_lora_expand,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake, fake_impl=_lora_expand_fake,
dispatch_key=current_platform.dispatch_key,
) )
lora_expand = torch.ops.vllm.lora_expand lora_expand = torch.ops.vllm.lora_expand
......
...@@ -11,7 +11,6 @@ import torch ...@@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -237,7 +236,6 @@ try: ...@@ -237,7 +236,6 @@ try:
op_func=_lora_shrink, op_func=_lora_shrink,
mutates_args=["output_tensor"], mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake, fake_impl=_lora_shrink_fake,
dispatch_key=current_platform.dispatch_key,
) )
lora_shrink = torch.ops.vllm.lora_shrink lora_shrink = torch.ops.vllm.lora_shrink
......
...@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake( ...@@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8", op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8, op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
......
...@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, ...@@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
direct_register_custom_op( direct_register_custom_op(
op_name="fused_marlin_moe", op_name="fused_marlin_moe",
op_func=fused_marlin_moe, op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake, fake_impl=fused_marlin_moe_fake,
) )
...@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake( ...@@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="outplace_fused_experts", op_name="outplace_fused_experts",
op_func=outplace_fused_experts, op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake, fake_impl=outplace_fused_experts_fake,
tags=(() if is_torch_equal_or_newer("2.7.0") else tags=(() if is_torch_equal_or_newer("2.7.0") else
(torch.Tag.needs_fixed_stride_order, )), (torch.Tag.needs_fixed_stride_order, )),
......
...@@ -2040,7 +2040,6 @@ direct_register_custom_op( ...@@ -2040,7 +2040,6 @@ direct_register_custom_op(
op_func=moe_forward, op_func=moe_forward,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=moe_forward_fake, fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
...@@ -2071,7 +2070,6 @@ direct_register_custom_op( ...@@ -2071,7 +2070,6 @@ direct_register_custom_op(
op_func=moe_forward_shared, op_func=moe_forward_shared,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake, fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
......
...@@ -223,17 +223,13 @@ if current_platform.is_rocm(): ...@@ -223,17 +223,13 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1", op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl, op_func=rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_tkw1_fake, fake_impl=rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_fused_moe", op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl, op_func=rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_moe_fake, fake_impl=rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
...@@ -241,7 +237,6 @@ if current_platform.is_rocm(): ...@@ -241,7 +237,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_topk_softmax_impl, op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake, fake_impl=rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
...@@ -249,7 +244,6 @@ if current_platform.is_rocm(): ...@@ -249,7 +244,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_biased_grouped_topk_impl, op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"], mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake, fake_impl=rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
...@@ -257,7 +251,6 @@ if current_platform.is_rocm(): ...@@ -257,7 +251,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_grouped_topk_impl, op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"], mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake, fake_impl=rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
) )
......
...@@ -103,17 +103,13 @@ if current_platform.is_rocm(): ...@@ -103,17 +103,13 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rms_norm", op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl, op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake, fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
) )
......
...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
...@@ -401,5 +400,4 @@ direct_register_custom_op( ...@@ -401,5 +400,4 @@ direct_register_custom_op(
op_func=linear_attention, op_func=linear_attention,
mutates_args=["output"], mutates_args=["output"],
fake_impl=linear_attention_fake, fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
) )
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update) selective_scan_fn, selective_state_update)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
...@@ -464,5 +463,4 @@ direct_register_custom_op( ...@@ -464,5 +463,4 @@ direct_register_custom_op(
op_func=mamba_mixer, op_func=mamba_mixer,
mutates_args=["output"], mutates_args=["output"],
fake_impl=mamba_mixer_fake, fake_impl=mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
) )
...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader) LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -765,5 +764,4 @@ direct_register_custom_op( ...@@ -765,5 +764,4 @@ direct_register_custom_op(
op_func=mamba_mixer2, op_func=mamba_mixer2,
mutates_args=["output"], mutates_args=["output"],
fake_impl=mamba_mixer2_fake, fake_impl=mamba_mixer2_fake,
dispatch_key=current_platform.dispatch_key,
) )
...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator) MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ( from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata) ShortConvAttentionMetadata)
...@@ -251,5 +250,4 @@ direct_register_custom_op( ...@@ -251,5 +250,4 @@ direct_register_custom_op(
op_func=short_conv, op_func=short_conv,
mutates_args=["output"], mutates_args=["output"],
fake_impl=short_conv_fake, fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
) )
...@@ -4,7 +4,6 @@ import logging ...@@ -4,7 +4,6 @@ import logging
import torch import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt from vllm.utils.deep_gemm import fp8_gemm_nt
...@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake( ...@@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm", op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm, op_func=w8a8_deepgemm_block_scaled_mm,
mutates_args=[],
fake_impl=w8a8_deepgemm_block_scaled_mm_fake, fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
) )
...@@ -161,7 +161,6 @@ try: ...@@ -161,7 +161,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_fused_mul_mat_gguf", op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf, op_func=_fused_mul_mat_gguf,
mutates_args=[],
fake_impl=_fused_mul_mat_gguf_fake, fake_impl=_fused_mul_mat_gguf_fake,
) )
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
...@@ -273,7 +272,6 @@ try: ...@@ -273,7 +272,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_fused_moe_gguf", op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf, op_func=_fused_moe_gguf,
mutates_args=[],
fake_impl=_fused_moe_gguf_fake, fake_impl=_fused_moe_gguf_fake,
) )
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
...@@ -319,7 +317,6 @@ try: ...@@ -319,7 +317,6 @@ try:
direct_register_custom_op( direct_register_custom_op(
op_name="_apply_gguf_embedding", op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding, op_func=_apply_gguf_embedding,
mutates_args=[],
fake_impl=_apply_gguf_embedding_fake, fake_impl=_apply_gguf_embedding_fake,
) )
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
......
...@@ -51,9 +51,7 @@ if current_platform.is_rocm(): ...@@ -51,9 +51,7 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8", op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl, op_func=rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_fake, fake_impl=rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
) )
......
...@@ -91,9 +91,7 @@ if current_platform.is_rocm(): ...@@ -91,9 +91,7 @@ if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale", op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl, op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
) )
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()): and current_platform.is_fp8_fnuz()):
...@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake( ...@@ -135,7 +133,6 @@ def _w8a8_triton_block_scaled_mm_fake(
direct_register_custom_op( direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func", "w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func, _w8a8_triton_block_scaled_mm_func,
mutates_args=[],
fake_impl=_w8a8_triton_block_scaled_mm_fake, fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA", dispatch_key="CUDA",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment