Unverified Commit 6ac5e06f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
parent 5c2acb27
...@@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous ...@@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer from vllm.utils import get_tcp_uri
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_torch_equal from vllm.utils.torch_utils import is_torch_equal
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None: ...@@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None:
def use_aot_compile() -> bool: def use_aot_compile() -> bool:
from vllm.utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit @triton.jit
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit @triton.jit
......
...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8( def flashinfer_fused_moe_blockscale_fp8(
......
...@@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc ...@@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
......
...@@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
......
...@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
......
...@@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.utils import cdiv
from vllm.utils.flashinfer import flashinfer_fp4_quantize from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.torch_utils import is_torch_equal_or_newer
@triton.jit @triton.jit
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool: def is_rocm_aiter_rmsnorm_enabled() -> bool:
......
...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update, selective_state_update,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......
...@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
......
...@@ -6,7 +6,10 @@ import torch ...@@ -6,7 +6,10 @@ import torch
from vllm.config.cache import MambaDType from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.distributed import divide from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype,
)
class MambaStateDtypeCalculator: class MambaStateDtypeCalculator:
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import (
QuantizationMethods, QuantizationMethods,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig): class BitsAndBytesConfig(QuantizationConfig):
......
...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf ...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
class FPQuantConfig(QuantizationConfig): class FPQuantConfig(QuantizationConfig):
......
...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import (
) )
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment