"vscode:/vscode.git/clone" did not exist on "235c9db8a755e0404628a568bf29a492257fe52e"
Unverified commit 6ac5e06f, authored by Isotr0py and committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
parent 5c2acb27
...@@ -49,10 +49,10 @@ from vllm.platforms import current_platform ...@@ -49,10 +49,10 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import ( from vllm.utils import (
has_triton_kernels, has_triton_kernels,
is_torch_equal_or_newer,
round_up, round_up,
) )
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -45,7 +45,7 @@ try: ...@@ -45,7 +45,7 @@ try:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant from aiter.ops.triton.quant import dynamic_mxfp4_quant
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
if is_rocm_aiter_fp4_asm_gemm_enabled(): if is_rocm_aiter_fp4_asm_gemm_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip from aiter import gemm_a4w4, per_1x32_f4_quant_hip
......
...@@ -28,13 +28,13 @@ from vllm.model_executor.parameter import ( ...@@ -28,13 +28,13 @@ from vllm.model_executor.parameter import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear, should_use_deepgemm_for_fp8_linear,
) )
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def _quant_dequant_mxfp6( def _quant_dequant_mxfp6(
......
...@@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config ...@@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool: def is_rocm_triton_rotary_embedding_enabled() -> bool:
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def shuffle_weight(w: torch.Tensor) -> torch.Tensor: def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
......
...@@ -11,8 +11,8 @@ from vllm.logger import init_logger ...@@ -11,8 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
initialize_model, initialize_model,
process_weights_after_loading, process_weights_after_loading,
set_default_torch_dtype,
) )
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_safetensors_index_file_from_hf,
download_weights_from_hf, download_weights_from_hf,
...@@ -48,6 +48,7 @@ from vllm.model_executor.utils import ( ...@@ -48,6 +48,7 @@ from vllm.model_executor.utils import (
set_weight_attrs, set_weight_attrs,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader ...@@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
initialize_model, initialize_model,
process_weights_after_loading, process_weights_after_loading,
set_default_torch_dtype,
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, get_gguf_extra_tensor_names,
get_gguf_weight_type_map, get_gguf_weight_type_map,
gguf_quant_weights_iterator, gguf_quant_weights_iterator,
) )
from vllm.utils.torch_utils import set_default_torch_dtype
class GGUFModelLoader(BaseModelLoader): class GGUFModelLoader(BaseModelLoader):
......
...@@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import ( ...@@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import (
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_model_architecture, get_model_architecture,
initialize_model, initialize_model,
set_default_torch_dtype,
) )
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader ...@@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
initialize_model, initialize_model,
process_weights_after_loading, process_weights_after_loading,
set_default_torch_dtype,
) )
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib
import inspect import inspect
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
...@@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available ...@@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The default dtype that was active on entry is restored when the
    ``with`` block exits, whether it exits normally or by raising.

    Args:
        dtype: The dtype to install as the torch default for the
            duration of the ``with`` block.

    Yields:
        None.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even when the body raises; otherwise the process-wide
        # default dtype would stay changed after a failure inside the block.
        torch.set_default_dtype(old_dtype)
def initialize_model( def initialize_model(
vllm_config: VllmConfig, vllm_config: VllmConfig,
*, *,
......
...@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING ...@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up from vllm.utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadata,
......
...@@ -18,7 +18,6 @@ from vllm.config import VllmConfig ...@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers.utils import replace_linear_class from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
...@@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo ...@@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import ( from .utils import (
......
...@@ -51,8 +51,8 @@ from vllm.multimodal.processing import ( ...@@ -51,8 +51,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import set_default_torch_num_threads
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
......
...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import ( ...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import (
Resampler2, Resampler2,
get_2d_sincos_pos_embed, get_2d_sincos_pos_embed,
) )
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -88,6 +87,7 @@ from vllm.platforms import current_platform ...@@ -88,6 +87,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.collections import flatten_2d_lists from vllm.utils.collections import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import ( from .interfaces import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment