Unverified Commit 4172235a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 deprecation] Deprecate V0 Neuron backend (#21159)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 848562bd
......@@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]:
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
def neuron_platform_plugin() -> Optional[str]:
tnx_installed = False
nxd_installed = False
logger.debug("Checking if Neuron platform is available.")
try:
import transformers_neuronx # noqa: F401
tnx_installed = True
logger.debug("Confirmed Neuron platform is available because"
" transformers_neuronx is found.")
except ImportError:
pass
try:
import neuronx_distributed_inference # noqa: F401
nxd_installed = True
logger.debug("Confirmed Neuron platform is available because"
" neuronx_distributed_inference is found.")
except ImportError:
pass
is_neuron = tnx_installed or nxd_installed
return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None
builtin_platform_plugins = {
'tpu': tpu_platform_plugin,
'cuda': cuda_platform_plugin,
'rocm': rocm_platform_plugin,
'xpu': xpu_platform_plugin,
'cpu': cpu_platform_plugin,
'neuron': neuron_platform_plugin,
}
......
......@@ -73,7 +73,6 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto()
......@@ -164,9 +163,6 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
class NeuronFramework(enum.Enum):
TRANSFORMERS_NEURONX = "transformers-neuronx"
NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_name: str = "neuron"
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
dist_backend: str = "gloo"
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "neuron"
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"
if vllm_config.cache_config and vllm_config.model_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
if envs.VLLM_USE_V1:
return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa
else:
return Platform.get_device_communicator_cls()
@classmethod
def use_all_gather(cls) -> bool:
return True
@classmethod
@lru_cache
def is_neuronx_distributed_inference(cls) -> bool:
try:
import neuronx_distributed_inference
except ImportError:
neuronx_distributed_inference = None
return neuronx_distributed_inference is not None
@classmethod
@lru_cache
def is_transformers_neuronx(cls) -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
def get_neuron_framework_to_use(self):
"""Return the specified framework if corresponding installations are
available.
If no framework is specified, use neuronx-distributed-inference by
default.
If that's unavailable, check and switch to transformers-neuronx.
"""
if not self.is_neuron():
raise AssertionError(
f"Neuron Framework unavailable for platform: {self}")
tnx_installed = self.is_transformers_neuronx()
nxd_installed = self.is_neuronx_distributed_inference()
specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
if specified_framework == tnx_framework and tnx_installed:
return self.TRANSFORMERS_NEURONX
if ((specified_framework == nxd_framework and nxd_installed)
or (specified_framework is None and nxd_installed)):
return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
if specified_framework is None and tnx_installed:
return NeuronFramework.TRANSFORMERS_NEURONX
return None
def use_neuronx_distributed(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
is used to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
return self.get_neuron_framework_to_use() == nxd_framework
def use_transformers_neuronx(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
return self.get_neuron_framework_to_use(
) == NeuronFramework.TRANSFORMERS_NEURONX
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import os
from typing import List, Optional, Set, Tuple
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class NeuronWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
model_runner: NeuronModelRunner
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
self.lora_config = vllm_config.lora_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
neuron_framework = current_platform.get_neuron_framework_to_use()
if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
self.model_runner = self.get_tnx_model_runner(vllm_config)
elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
self.model_runner = self.get_neuronx_distributed_model_runner(
vllm_config)
else:
raise NotImplementedError(
"Specified framework" +
f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
" is either not installed or not supported." +
" Supported frameworks: " +
"[transformers-neuronx, neuronx-distributed-inference]")
def get_tnx_model_runner(self, vllm_config):
assert (self.lora_config
is None), ("LoRA is not supported for TransformersNeuronX "
"framework.")
if self.speculative_config is not None:
raise NotImplementedError(
"Speculative decoding is not supported for TransformersNeuronX"
)
return NeuronModelRunner(vllm_config=vllm_config)
def get_neuronx_distributed_model_runner(self, vllm_config):
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
if self.speculative_config is not None:
assert (self.lora_config is None), (
"LoRA is not supported for Speculative Decoding")
raise NotImplementedError(
"Speculative decoding is not supported for NeuronxDistributed")
return NeuronxDistributedModelRunner(vllm_config=vllm_config)
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
We configure num_gpu_blocks to be equal to max_num_seqs.
"""
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
num_gpu_blocks = self.scheduler_config.max_num_seqs + 1
# Swap not yet supported with Neuron backend.
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache.
"""
# Different values are not tested.
assert num_cpu_blocks == 0
assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@property
def do_metadata_broadcast(self) -> bool:
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )
def execute_worker(self, worker_input: WorkerInput) -> None:
pass
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the environment initialized when TP/PP > 1
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
1,
1,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.list_loras()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment