Unverified Commit d6e634f3 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[TPU] Suppress import custom_ops warning (#7458)

parent 4d2dc507
......@@ -6,13 +6,15 @@ import torch
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
if not current_platform.is_tpu():
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
# ruff: noqa: F401
......
......@@ -29,7 +29,6 @@ import torch.types
from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__)
......@@ -359,6 +358,7 @@ def is_xpu() -> bool:
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
from vllm import _custom_ops as ops
max_shared_mem = (
ops.get_max_shared_memory_per_block_device_attribute(gpu))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment