Unverified Commit 6ac5e06f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
parent 5c2acb27
...@@ -13,7 +13,7 @@ from vllm import LLM ...@@ -13,7 +13,7 @@ from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
......
...@@ -10,7 +10,8 @@ import torch ...@@ -10,7 +10,8 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
......
...@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata( def get_aiter_mla_metadata(
......
...@@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import ( ...@@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import (
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .caching import VllmSerializableFunction from .caching import VllmSerializableFunction
from .compiler_interface import ( from .compiler_interface import (
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
......
...@@ -16,7 +16,7 @@ import torch.fx as fx ...@@ -16,7 +16,7 @@ import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
class CompilerInterface: class CompilerInterface:
......
...@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo ...@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors from vllm.utils.torch_utils import weak_ref_tensors
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher ...@@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo
from .monitor import start_monitoring_torch_compile from .monitor import start_monitoring_torch_compile
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
from torch import fx from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"): if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass from torch._inductor.custom_graph_pass import CustomGraphPass
......
...@@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass ...@@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
...@@ -41,8 +41,9 @@ from vllm.transformers_utils.config import ( ...@@ -41,8 +41,9 @@ from vllm.transformers_utils.config import (
) )
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import LayerBlockType, common_broadcastable_dtype from vllm.utils import LayerBlockType
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
......
...@@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_ports_list from vllm.utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
......
...@@ -22,7 +22,8 @@ from vllm.logger import init_logger ...@@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.utils import cuda_device_count_stateless, update_environment_variables from vllm.utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import ( ...@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
try: try:
ops.meta_size() ops.meta_size()
......
...@@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ...@@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
) )
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import current_stream from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): ...@@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
from vllm.distributed.device_communicators.pynccl_allocator import ( from vllm.distributed.device_communicators.pynccl_allocator import (
nccl_symm_mem_context, nccl_symm_mem_context,
) )
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
global _NCCL_SYMM_OPS_REGISTERED global _NCCL_SYMM_OPS_REGISTERED
if _NCCL_SYMM_OPS_REGISTERED: if _NCCL_SYMM_OPS_REGISTERED:
......
...@@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config ...@@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ...@@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
) )
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import current_stream from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ...@@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool, TensorMemoryPool,
) )
from vllm.utils import current_stream, get_ip from vllm.utils import get_ip
from vllm.utils.torch_utils import current_stream
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ...@@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import ( from vllm.utils import (
direct_register_custom_op,
get_distributed_init_method, get_distributed_init_method,
supports_custom_op,
) )
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import (
direct_register_custom_op,
supports_custom_op,
)
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment