Unverified Commit 4ebc9108 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Kernel] Centralize platform kernel import in `current_platform.import_kernels` (#26286)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent e1ba2356
......@@ -12,8 +12,7 @@ from vllm.scalar_type import ScalarType
logger = init_logger(__name__)
current_platform.import_core_kernels()
supports_moe_ops = current_platform.try_import_moe_kernels()
current_platform.import_kernels()
if TYPE_CHECKING:
......@@ -1921,7 +1920,7 @@ def moe_wna16_marlin_gemm(
)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(
......
......@@ -170,22 +170,15 @@ class Platform:
return device_id
@classmethod
def import_core_kernels(cls) -> None:
def import_kernels(cls) -> None:
"""Import any platform-specific C kernels."""
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
@classmethod
def try_import_moe_kernels(cls) -> bool:
"""Import any platform-specific MoE kernels."""
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
return True
return False
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, Union, cast
import torch
......@@ -45,8 +46,10 @@ class TpuPlatform(Platform):
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
@classmethod
def import_core_kernels(cls) -> None:
pass
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
from typing import TYPE_CHECKING, Optional
......@@ -35,8 +36,10 @@ class XPUPlatform(Platform):
device_control_env_var: str = "ZE_AFFINITY_MASK"
@classmethod
def import_core_kernels(cls) -> None:
pass
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment