Unverified Commit 6ac5e06f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
parent 5c2acb27
......@@ -10,7 +10,8 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
def with_triton_mode(fn):
......
......@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
......
......@@ -7,7 +7,8 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()
......
......@@ -9,9 +9,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
......
......@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()
......
......@@ -9,9 +9,9 @@ from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
......
......@@ -12,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash,
)
......
......@@ -11,7 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
@contextlib.contextmanager
......
......@@ -20,7 +20,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
......
......@@ -19,7 +19,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
......
......@@ -27,7 +27,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
......
......@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
......
......@@ -15,7 +15,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
def reference_fn(x: torch.Tensor):
......
......@@ -5,7 +5,7 @@ import dataclasses
import pytest
from vllm.config import CompilationMode
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import compare_all_settings
......
......@@ -8,7 +8,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
def test_version():
......
......@@ -15,7 +15,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
......
......@@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
......
......@@ -15,8 +15,8 @@ from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test
......
......@@ -60,8 +60,8 @@ from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import set_default_torch_num_threads
from vllm.utils.collections import is_list_of
from vllm.utils.torch_utils import set_default_torch_num_threads
logger = init_logger(__name__)
......
......@@ -18,7 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment