Unverified Commit 6ac5e06f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarisotr0py <2037008807@qq.com>
parent 5c2acb27
......@@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
from vllm.utils import get_tcp_uri
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
......
......@@ -5,7 +5,7 @@ import os
import torch
from vllm.logger import init_logger
from vllm.utils import is_torch_equal
from vllm.utils.torch_utils import is_torch_equal
logger = init_logger(__name__)
......
......@@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None:
def use_aot_compile() -> bool:
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
......
......@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
......
......@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
......
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(
......
......@@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
......
......@@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
......
......@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):
......
......@@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.utils import cdiv
from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.torch_utils import is_torch_equal_or_newer
@triton.jit
......
......@@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
......
......@@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING:
......
......@@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......
......@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
......
......@@ -6,7 +6,10 @@ import torch
from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype,
)
class MambaStateDtypeCalculator:
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import (
QuantizationMethods,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig):
......
......@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class FPQuantConfig(QuantizationConfig):
......
......@@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
......
......@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment