Unverified Commit 4197168e authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

[ez] Remove checks for torch version <= 2.8 (#33209)


Signed-off-by: default avatarangelayi <yiangela7@gmail.com>
parent 59bcc5b6
...@@ -8,7 +8,7 @@ from torch._ops import OpOverload ...@@ -8,7 +8,7 @@ from torch._ops import OpOverload
import vllm.envs as envs import vllm.envs as envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer, rocm_aiter_sparse_attn_indexer,
rocm_aiter_sparse_attn_indexer_fake, rocm_aiter_sparse_attn_indexer_fake,
...@@ -1015,12 +1015,6 @@ class rocm_aiter_ops: ...@@ -1015,12 +1015,6 @@ class rocm_aiter_ops:
def register_ops_once() -> None: def register_ops_once() -> None:
global _OPS_REGISTERED global _OPS_REGISTERED
if not _OPS_REGISTERED: if not _OPS_REGISTERED:
tags = (
tuple()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
)
# register all the custom ops here # register all the custom ops here
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1", op_name="rocm_aiter_asm_moe_tkw1",
...@@ -1075,7 +1069,6 @@ class rocm_aiter_ops: ...@@ -1075,7 +1069,6 @@ class rocm_aiter_ops:
op_func=_rocm_aiter_mla_decode_fwd_impl, op_func=_rocm_aiter_mla_decode_fwd_impl,
mutates_args=["o"], mutates_args=["o"],
fake_impl=_rocm_aiter_mla_decode_fwd_fake, fake_impl=_rocm_aiter_mla_decode_fwd_fake,
tags=tags,
) )
direct_register_custom_op( direct_register_custom_op(
......
...@@ -33,7 +33,6 @@ from vllm.logger import init_logger ...@@ -33,7 +33,6 @@ from vllm.logger import init_logger
from vllm.logging_utils import lazy from vllm.logging_utils import lazy
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .compiler_interface import ( from .compiler_interface import (
CompilerInterface, CompilerInterface,
...@@ -94,10 +93,8 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: ...@@ -94,10 +93,8 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.backend == "inductor": if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough, # Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build. # and the symbol actually exists in this PyTorch build.
if ( if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
envs.VLLM_USE_STANDALONE_COMPILE torch._inductor, "standalone_compile"
and is_torch_equal_or_newer("2.8.0.dev")
and hasattr(torch._inductor, "standalone_compile")
): ):
logger.debug("Using InductorStandaloneAdaptor") logger.debug("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor( return InductorStandaloneAdaptor(
......
...@@ -501,20 +501,19 @@ class InductorAdaptor(CompilerInterface): ...@@ -501,20 +501,19 @@ class InductorAdaptor(CompilerInterface):
# get hit. # get hit.
# TODO(zou3519): we're going to replace this all with # TODO(zou3519): we're going to replace this all with
# standalone_compile sometime. # standalone_compile sometime.
if is_torch_equal_or_newer("2.6"): stack.enter_context(
stack.enter_context( torch._inductor.config.patch(fx_graph_remote_cache=False)
torch._inductor.config.patch(fx_graph_remote_cache=False) )
) # InductorAdaptor (unfortunately) requires AOTAutogradCache
# InductorAdaptor (unfortunately) requires AOTAutogradCache # to be turned off to run. It will fail to acquire the hash_str
# to be turned off to run. It will fail to acquire the hash_str # and error if not.
# and error if not. # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem. stack.enter_context(
stack.enter_context( torch._functorch.config.patch(enable_autograd_cache=False)
torch._functorch.config.patch(enable_autograd_cache=False) )
) stack.enter_context(
stack.enter_context( torch._functorch.config.patch(enable_remote_autograd_cache=False)
torch._functorch.config.patch(enable_remote_autograd_cache=False) )
)
compiled_graph = compile_fx( compiled_graph = compile_fx(
graph, graph,
......
...@@ -7,12 +7,11 @@ import inspect ...@@ -7,12 +7,11 @@ import inspect
import os import os
import sys import sys
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload from typing import TYPE_CHECKING, Any, TypeVar, overload
from unittest.mock import patch from unittest.mock import patch
import torch import torch
import torch.nn as nn import torch.nn as nn
from packaging import version
from torch._dynamo.symbolic_convert import InliningInstructionTranslator from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs import vllm.envs as envs
...@@ -540,7 +539,6 @@ def _support_torch_compile( ...@@ -540,7 +539,6 @@ def _support_torch_compile(
torch._dynamo.config.patch(**dynamo_config_patches), torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config), maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches), torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches), torch._inductor.config.patch(**inductor_config_patches),
): ):
use_aot_compile = envs.VLLM_USE_AOT_COMPILE use_aot_compile = envs.VLLM_USE_AOT_COMPILE
...@@ -647,42 +645,3 @@ def maybe_use_cudagraph_partition_wrapper( ...@@ -647,42 +645,3 @@ def maybe_use_cudagraph_partition_wrapper(
and compilation_config.use_inductor_graph_partition and compilation_config.use_inductor_graph_partition
): ):
torch._inductor.utils.set_customized_partition_wrappers(None) torch._inductor.utils.set_customized_partition_wrappers(None)
@contextlib.contextmanager
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
`BasevLLMParameters` without having to replace them with regular tensors
before `torch.compile`-time.
"""
from vllm.model_executor.parameter import (
BasevLLMParameter,
ModelWeightParameter,
RowvLLMParameter,
_ColumnvLLMParameter,
)
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
return False
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
yield
return
with (
torch._dynamo.config.patch(
"traceable_tensor_subclasses",
[
BasevLLMParameter,
ModelWeightParameter,
_ColumnvLLMParameter,
RowvLLMParameter,
],
),
patch(
"torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
),
):
yield
...@@ -16,18 +16,10 @@ import torch ...@@ -16,18 +16,10 @@ import torch
from torch import fx from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config.utils import Range from vllm.config.utils import Range
if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version
from .torch25_custom_graph_pass import (
Torch25CustomGraphPass as CustomGraphPass,
)
_pass_context = None _pass_context = None
P = ParamSpec("P") P = ParamSpec("P")
......
...@@ -777,10 +777,9 @@ class CompilationConfig: ...@@ -777,10 +777,9 @@ class CompilationConfig:
# and it is not yet a priority. RFC here: # and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703 # https://github.com/vllm-project/vllm/issues/14703
if is_torch_equal_or_newer("2.6"): KEY = "enable_auto_functionalized_v2"
KEY = "enable_auto_functionalized_v2" if KEY not in self.inductor_compile_config:
if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False
self.inductor_compile_config[KEY] = False
for k, v in self.inductor_passes.items(): for k, v in self.inductor_passes.items():
if not isinstance(v, str): if not isinstance(v, str):
......
...@@ -31,7 +31,6 @@ import vllm.envs as envs ...@@ -31,7 +31,6 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -429,20 +428,11 @@ def init_gloo_process_group( ...@@ -429,20 +428,11 @@ def init_gloo_process_group(
different torch versions. different torch versions.
""" """
with suppress_stdout(): with suppress_stdout():
if is_torch_equal_or_newer("2.6"): pg = ProcessGroup(
pg = ProcessGroup( prefix_store,
prefix_store, group_rank,
group_rank, group_size,
group_size, )
)
else:
options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
from torch.distributed.distributed_c10d import ProcessGroupGloo from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo( backend_class = ProcessGroupGloo(
...@@ -450,9 +440,7 @@ def init_gloo_process_group( ...@@ -450,9 +440,7 @@ def init_gloo_process_group(
) )
backend_type = ProcessGroup.BackendType.GLOO backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu") device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"): pg._set_default_backend(backend_type)
# _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group() backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class) pg._register_backend(device, backend_type, backend_class)
...@@ -534,12 +522,5 @@ def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: ...@@ -534,12 +522,5 @@ def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
Destroy ProcessGroup returned by Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group(). stateless_init_torch_distributed_process_group().
""" """
if is_torch_equal_or_newer("2.7"): pg.shutdown()
pg.shutdown()
else:
# Lazy import for non-CUDA backends.
from torch.distributed.distributed_c10d import _shutdown_backend
_shutdown_backend(pg)
_unregister_process_group(pg.group_name) _unregister_process_group(pg.group_name)
...@@ -52,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -52,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1406,11 +1406,6 @@ direct_register_custom_op( ...@@ -1406,11 +1406,6 @@ direct_register_custom_op(
op_func=inplace_fused_experts, op_func=inplace_fused_experts,
mutates_args=["hidden_states"], mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake, fake_impl=inplace_fused_experts_fake,
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
) )
...@@ -1501,11 +1496,6 @@ direct_register_custom_op( ...@@ -1501,11 +1496,6 @@ direct_register_custom_op(
op_name="outplace_fused_experts", op_name="outplace_fused_experts",
op_func=outplace_fused_experts, op_func=outplace_fused_experts,
fake_impl=outplace_fused_experts_fake, fake_impl=outplace_fused_experts_fake,
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
) )
......
...@@ -56,7 +56,6 @@ from vllm.scalar_type import scalar_types ...@@ -56,7 +56,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -89,7 +88,6 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend: ...@@ -89,7 +88,6 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
# If FlashInfer is not available, try either Marlin or Triton # If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = ( triton_kernels_supported = (
has_triton_kernels() has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100 # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
...@@ -151,7 +149,6 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: ...@@ -151,7 +149,6 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
# If FlashInfer is not available, try either Marlin or Triton # If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = ( triton_kernels_supported = (
has_triton_kernels() has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100 # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
......
...@@ -108,20 +108,6 @@ class TorchAOConfig(QuantizationConfig): ...@@ -108,20 +108,6 @@ class TorchAOConfig(QuantizationConfig):
skip_modules: list[str] | None = None, skip_modules: list[str] | None = None,
is_checkpoint_torchao_serialized: bool = False, is_checkpoint_torchao_serialized: bool = False,
) -> None: ) -> None:
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
if is_torch_equal_or_newer("2.8.0.dev"):
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
logger.info(
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
# TODO: remove after the torch dependency is updated to 2.8
if is_torch_equal_or_newer(
"2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
super().__init__() super().__init__()
self.torchao_config = torchao_config self.torchao_config = torchao_config
self.skip_modules = skip_modules or [] self.skip_modules = skip_modules or []
......
...@@ -709,9 +709,7 @@ def is_torch_equal(target: str) -> bool: ...@@ -709,9 +709,7 @@ def is_torch_equal(target: str) -> bool:
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool: def supports_xccl() -> bool:
return ( return torch.distributed.is_xccl_available()
is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
)
# create a library to hold the custom op # create a library to hold the custom op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment