Unverified Commit 3ae082c3 authored by dongbo910220's avatar dongbo910220 Committed by GitHub
Browse files

[Chore] Separate out optional dependency checks from vllm.utils (#27207)


Signed-off-by: default avatardongbo910220 <1275604947@qq.com>
Signed-off-by: default avatardongbo910220 <32610838+dongbo910220@users.noreply.github.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 49c00fe3
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm from vllm.utils import cdiv
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
_ceil_to_ue8m0, _ceil_to_ue8m0,
calc_diff, calc_diff,
...@@ -15,6 +15,7 @@ from vllm.utils.deep_gemm import ( ...@@ -15,6 +15,7 @@ from vllm.utils.deep_gemm import (
get_num_sms, get_num_sms,
get_paged_mqa_logits_metadata, get_paged_mqa_logits_metadata,
) )
from vllm.utils.import_utils import has_deep_gemm
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import ( from .mk_objects import (
TestMoEQuantConfig, TestMoEQuantConfig,
......
...@@ -35,9 +35,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -35,9 +35,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported, cutlass_fp8_supported,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
@dataclass @dataclass
......
...@@ -15,7 +15,7 @@ from torch.distributed import ProcessGroup ...@@ -15,7 +15,7 @@ from torch.distributed import ProcessGroup
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from vllm.utils import has_deep_ep from vllm.utils.import_utils import has_deep_ep
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
if has_deep_ep(): if has_deep_ep():
......
...@@ -21,11 +21,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -21,11 +21,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe, modular_triton_fused_moe,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
) )
from vllm.utils.import_utils import has_deep_gemm
dg_available = has_deep_gemm() dg_available = has_deep_gemm()
......
...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -21,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep from vllm.utils.import_utils import has_deep_ep
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels(): if not has_triton_kernels():
pytest.skip( pytest.skip(
......
...@@ -13,8 +13,8 @@ import torch ...@@ -13,8 +13,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from .modular_kernel_tools.common import ( from .modular_kernel_tools.common import (
......
...@@ -18,12 +18,12 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -18,12 +18,12 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm, w8a8_triton_block_scaled_mm,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
get_col_major_tma_aligned_tensor, get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8, per_block_cast_to_fp8,
) )
from vllm.utils.import_utils import has_deep_gemm
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
......
...@@ -12,7 +12,6 @@ from tests.quantization.utils import is_quant_method_supported ...@@ -12,7 +12,6 @@ from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils import STR_BACKEND_ENV_VAR
from ..utils import check_logprobs_close from ..utils import check_logprobs_close
......
...@@ -9,8 +9,8 @@ import vllm.envs as envs ...@@ -9,8 +9,8 @@ import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx
from .base_device_communicator import All2AllManagerBase, Cache from .base_device_communicator import All2AllManagerBase, Cache
......
...@@ -14,8 +14,9 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ...@@ -14,8 +14,9 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_Scheme, OCP_MX_Scheme,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils import cdiv, has_triton_kernels from vllm.utils import cdiv
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -26,12 +26,12 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache ...@@ -26,12 +26,12 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nt_contiguous,
) )
from vllm.utils.func_utils import run_once from vllm.utils.func_utils import run_once
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -55,8 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -55,8 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up from vllm.utils import cdiv, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
......
...@@ -93,7 +93,6 @@ from vllm.model_executor.parameter import ( ...@@ -93,7 +93,6 @@ from vllm.model_executor.parameter import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
get_col_major_tma_aligned_tensor, get_col_major_tma_aligned_tensor,
...@@ -102,6 +101,7 @@ from vllm.utils.deep_gemm import ( ...@@ -102,6 +101,7 @@ from vllm.utils.deep_gemm import (
should_use_deepgemm_for_fp8_linear, should_use_deepgemm_for_fp8_linear,
) )
from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
......
...@@ -48,11 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s ...@@ -48,11 +48,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_s
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import ( from vllm.utils import round_up
has_triton_kernels,
round_up,
)
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -5,7 +5,6 @@ import contextlib ...@@ -5,7 +5,6 @@ import contextlib
import datetime import datetime
import enum import enum
import getpass import getpass
import importlib
import inspect import inspect
import json import json
import multiprocessing import multiprocessing
...@@ -1062,46 +1061,6 @@ def check_use_alibi(model_config: ModelConfig) -> bool: ...@@ -1062,46 +1061,6 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
) )
@cache
def _has_module(module_name: str) -> bool:
"""Return True if *module_name* can be found in the current environment.
The result is cached so that subsequent queries for the same module incur
no additional overhead.
"""
return importlib.util.find_spec(module_name) is not None
def has_pplx() -> bool:
"""Whether the optional `pplx_kernels` package is available."""
return _has_module("pplx_kernels")
def has_deep_ep() -> bool:
"""Whether the optional `deep_ep` package is available."""
return _has_module("deep_ep")
def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""
return _has_module("deep_gemm")
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
return _has_module("triton_kernels")
def has_tilelang() -> bool:
"""Whether the optional `tilelang` package is available."""
return _has_module("tilelang")
def set_process_title( def set_process_title(
name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
) -> None: ) -> None:
......
...@@ -16,7 +16,8 @@ import torch ...@@ -16,7 +16,8 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import logger from vllm.logger import logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm from vllm.utils import cdiv
from vllm.utils.import_utils import has_deep_gemm
@functools.cache @functools.cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment