Unverified commit 9dbcce84, authored by xendo and committed by GitHub
Browse files

[Neuron] [Bugfix] Fix neuron startup (#9374)


Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
parent a48e3ec0
...@@ -26,7 +26,8 @@ with contextlib.suppress(ImportError): ...@@ -26,7 +26,8 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
supports_moe_ops = True supports_moe_ops = True
if TYPE_CHECKING: # neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
def register_fake(fn): def register_fake(fn):
return lambda name: fn return lambda name: fn
......
...@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, ...@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu, is_hip, is_openvino, is_xpu, print_warning_once)
print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -215,8 +214,10 @@ class ModelConfig: ...@@ -215,8 +214,10 @@ class ModelConfig:
self.is_attention_free = self._init_attention_free() self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state() self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron( if current_platform.is_neuron():
) else None self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task, self.hf_config) supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
...@@ -368,7 +369,7 @@ class ModelConfig: ...@@ -368,7 +369,7 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.") " is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True envs.VLLM_USE_TRITON_AWQ = True
if is_neuron( if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization: ) and self.quantization not in neuron_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
...@@ -1112,7 +1113,7 @@ class DeviceConfig: ...@@ -1112,7 +1113,7 @@ class DeviceConfig:
# Automated device type detection # Automated device type detection
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
self.device_type = "cuda" self.device_type = "cuda"
elif is_neuron(): elif current_platform.is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_openvino(): elif is_openvino():
self.device_type = "openvino" self.device_type = "openvino"
......
...@@ -58,6 +58,13 @@ try: ...@@ -58,6 +58,13 @@ try:
except Exception: except Exception:
pass pass
# Neuron is detected purely by whether the transformers_neuronx package
# can be imported; the import itself is otherwise unused here.
try:
    import transformers_neuronx  # noqa: F401
except ImportError:
    is_neuron = False
else:
    is_neuron = True
if is_tpu: if is_tpu:
# people might install pytorch built with cuda but run on tpu # people might install pytorch built with cuda but run on tpu
# so we need to check tpu first # so we need to check tpu first
...@@ -75,6 +82,9 @@ elif is_xpu: ...@@ -75,6 +82,9 @@ elif is_xpu:
elif is_cpu: elif is_cpu:
from .cpu import CpuPlatform from .cpu import CpuPlatform
current_platform = CpuPlatform() current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
else: else:
current_platform = UnspecifiedPlatform() current_platform = UnspecifiedPlatform()
......
...@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum): ...@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto() TPU = enum.auto()
XPU = enum.auto() XPU = enum.auto()
CPU = enum.auto() CPU = enum.auto()
NEURON = enum.auto()
UNSPECIFIED = enum.auto() UNSPECIFIED = enum.auto()
...@@ -48,6 +49,9 @@ class Platform: ...@@ -48,6 +49,9 @@ class Platform:
def is_cpu(self) -> bool: def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
    """Return True when this platform's ``_enum`` is ``PlatformEnum.NEURON``."""
    return self._enum == PlatformEnum.NEURON
def is_cuda_alike(self) -> bool: def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`.""" """Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
......
from .interface import Platform, PlatformEnum
class NeuronPlatform(Platform):
    """Platform implementation tagged with ``PlatformEnum.NEURON``."""

    _enum = PlatformEnum.NEURON

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name.

        The result is always the literal ``"neuron"``; ``device_id`` is
        accepted but not consulted.
        """
        return "neuron"
from importlib.util import find_spec from importlib.util import find_spec
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
HAS_TRITON = find_spec("triton") is not None # neuron has too old torch
HAS_TRITON = find_spec(
"triton") is not None and not current_platform.is_neuron()
if not HAS_TRITON: if not HAS_TRITON:
logger.info("Triton not installed; certain GPU-related functions" logger.info("Triton not installed; certain GPU-related functions"
......
...@@ -327,15 +327,6 @@ def is_openvino() -> bool: ...@@ -327,15 +327,6 @@ def is_openvino() -> bool:
return False return False
@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Report whether the ``transformers_neuronx`` package can be imported.

    The result is memoized by ``lru_cache``, so the import is attempted at
    most once per process.
    """
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    else:
        return True
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def is_xpu() -> bool: def is_xpu() -> bool:
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
...@@ -786,7 +777,7 @@ def is_pin_memory_available() -> bool: ...@@ -786,7 +777,7 @@ def is_pin_memory_available() -> bool:
elif is_xpu(): elif is_xpu():
print_warning_once("Pin memory is not supported on XPU.") print_warning_once("Pin memory is not supported on XPU.")
return False return False
elif is_neuron(): elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.") print_warning_once("Pin memory is not supported on Neuron.")
return False return False
elif current_platform.is_cpu() or is_openvino(): elif current_platform.is_cpu() or is_openvino():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment