Unverified Commit 6ac5e06f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
parent 5c2acb27
......@@ -49,10 +49,10 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (
has_triton_kernels,
is_torch_equal_or_newer,
round_up,
)
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
......
......@@ -45,7 +45,7 @@ try:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
if is_rocm_aiter_fp4_asm_gemm_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
......
......@@ -28,13 +28,13 @@ from vllm.model_executor.parameter import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
......
......@@ -7,7 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
......
......@@ -3,7 +3,7 @@
import torch
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def _quant_dequant_mxfp6(
......
......@@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
......
......@@ -10,7 +10,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
......
......@@ -5,7 +5,7 @@ import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
......
......@@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
......
......@@ -11,8 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
......
......@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
......@@ -48,6 +48,7 @@ from vllm.model_executor.utils import (
set_weight_attrs,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
......
......@@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
from vllm.utils.torch_utils import set_default_torch_dtype
class GGUFModelLoader(BaseModelLoader):
......
......@@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import (
from vllm.model_executor.model_loader.utils import (
get_model_architecture,
initialize_model,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
......
......@@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import contextlib
import inspect
import warnings
from contextlib import contextmanager
......@@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def initialize_model(
vllm_config: VllmConfig,
*,
......
......@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
from vllm.utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING:
......
......@@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata,
......
......@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
......@@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (
......
......@@ -51,8 +51,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import set_default_torch_num_threads
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads
from .interfaces import (
MultiModalEmbeddings,
......
......@@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import (
Resampler2,
get_2d_sincos_pos_embed,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
......@@ -88,6 +87,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.collections import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment