Unverified Commit e90fc21f authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Hardware][Neuron] Refactor neuron support (#3471)

parent ea5f14e6
......@@ -12,7 +12,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(
model="openlm-research/open_llama_3b",
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_num_seqs=8,
# The max_model_len and block_size arguments are required to be same as
# max sequence length when targeting neuron device.
......@@ -24,7 +24,8 @@ llm = LLM(
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron")
device="neuron",
tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
......
......@@ -33,7 +33,7 @@ def test_worker_apply_lora(sql_lora_files):
max_loras=32),
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_model()
worker.init_device()
worker.load_model()
worker.model_runner.set_active_loras([], LoRAMapping([], []))
......
......@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
vocab_size = 32_000
......@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
......@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
......@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
......@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):
@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
......@@ -499,11 +499,11 @@ def test_init_model():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()
draft_worker.init_model.assert_called_once()
draft_worker.init_device.assert_called_once()
target_worker.init_model.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()
......
......@@ -123,7 +123,7 @@ def create_worker(cls: type,
is_driver_worker=is_driver_worker,
)
worker.init_model()
worker.init_device()
worker.load_model()
cache_config.num_gpu_blocks = num_gpu_blocks
......
......@@ -30,7 +30,7 @@ def test_swap() -> None:
)
# Initialize the worker.
worker.init_model()
worker.init_device()
worker.load_model()
worker.init_cache_engine(cache_config)
worker.warm_up_model()
......
......@@ -474,14 +474,6 @@ class ParallelConfig:
placement_group: Optional["PlacementGroup"] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron():
# For Neuron device support, here we assign TP=1 to avoid sharding
# within vLLM directly. Transformer-neuronx would take
# neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
......@@ -491,8 +483,7 @@ class ParallelConfig:
self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
if self.world_size > 1 and not is_neuron():
if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()
......@@ -591,10 +582,6 @@ class DeviceConfig:
# Set device with device type
self.device = torch.device(self.device_type)
@property
def is_neuron(self):
return self.device_type == "neuron"
@dataclass
class LoRAConfig:
......
......@@ -325,7 +325,12 @@ class AsyncLLMEngine:
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
device_config = engine_configs[4]
if device_config.device_type == "neuron":
raise NotImplementedError("Neuron is not supported for "
"async engine yet.")
elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
......
......@@ -125,9 +125,13 @@ class LLMEngine:
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
device_config = engine_configs[4]
# Initialize the cluster and specify the executor class.
if parallel_config.worker_use_ray:
if device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
......
import importlib
from typing import Dict, List, Optional
from vllm.lora.request import LoRARequest
......@@ -13,12 +12,6 @@ from vllm.utils import (get_ip, get_open_port, get_distributed_init_method,
logger = init_logger(__name__)
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
class GPUExecutor(ExecutorBase):
......@@ -44,17 +37,10 @@ class GPUExecutor(ExecutorBase):
# Profile the memory usage and initialize the cache.
self._init_cache()
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
from vllm.worker.worker import Worker
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
......@@ -73,7 +59,7 @@ class GPUExecutor(ExecutorBase):
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_model()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_cache(self) -> None:
......
from typing import Dict, List, Optional
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
self.cache_config.num_cpu_blocks = 0
# Instantiate the worker and load the model to the device.
self._init_worker()
def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker
self.driver_worker = NeuronWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
def list_loras(self) -> List[int]:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
......@@ -3,7 +3,6 @@ import copy
from collections import defaultdict
import os
import pickle
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
......@@ -25,12 +24,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
......@@ -73,13 +66,6 @@ class RayGPUExecutor(ExecutorBase):
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
......@@ -155,7 +141,7 @@ class RayGPUExecutor(ExecutorBase):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
from vllm.worker.worker import Worker
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
......@@ -201,7 +187,7 @@ class RayGPUExecutor(ExecutorBase):
# FIXME(woosuk): We are not properly initializing cupy NCCL when
# we have multiple nodes.
self._run_workers("init_model",
self._run_workers("init_device",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers(
......
......@@ -799,8 +799,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
self.device = device
@property
def logits_as_hidden_states(self):
return self.base_layer.logits_as_hidden_states
def logits_as_input(self):
return self.base_layer.logits_as_input
@property
def vocab_size(self):
......
from typing import List, Optional
import torch
from vllm.utils import in_wsl
from vllm.utils import is_pin_memory_available
class LoRALayerWeights:
......@@ -64,7 +64,7 @@ class LoRALayerWeights:
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl()
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
......
......@@ -11,7 +11,7 @@ import torch
from torch import nn
from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.utils import LRUCache, is_pin_memory_available
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_logits_processor)
......@@ -143,7 +143,7 @@ class LoRAModel:
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and not in_wsl()
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
......
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed, get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",
"SamplingMetadata",
"set_random_seed",
]
from dataclasses import dataclass, fields
from typing import Optional, List, Any, Dict
from typing import TYPE_CHECKING, Optional, List, Any, Dict
import torch
from xformers.ops.fmha.attn_bias import AttentionBias
if TYPE_CHECKING:
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass
......@@ -82,7 +83,7 @@ class InputMetadata:
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
self.attn_bias: Optional[List["AttentionBias"]] = None
# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
......
......@@ -4,8 +4,6 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.utils import is_neuron
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -23,7 +21,8 @@ class LogitsProcessor(nn.Module):
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0) -> None:
scale: Optional[float] = 1.0,
logits_as_input: bool = False) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
......@@ -31,8 +30,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.scale = scale
self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
......@@ -43,7 +42,7 @@ class LogitsProcessor(nn.Module):
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.logits_as_hidden_states:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
......
......@@ -4,13 +4,13 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
class Sampler(nn.Module):
......
......@@ -4,7 +4,7 @@ from typing import List, Optional, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip, is_neuron
from vllm.utils import is_hip
logger = init_logger(__name__)
......@@ -63,12 +63,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Sliding window attention is not yet supported in ROCm's flash attention",
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
"LlamaForCausalLM": "neuron.llama",
"MistralForCausalLM": "neuron.mistral"
}
class ModelRegistry:
......@@ -85,15 +79,8 @@ class ModelRegistry:
logger.warning(
f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
elif is_neuron():
if model_arch not in _NEURON_SUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"Neuron for now.")
module_name, model_cls_name = _MODELS[model_arch]
if is_neuron():
module_name = _NEURON_SUPPORTED_MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
......
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
with torch.inference_mode():
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids.flatten())
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.llama.model import LlamaForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers.models.llama import LlamaForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = LlamaForSampling.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment