Unverified Commit e90fc21f authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Hardware][Neuron] Refactor neuron support (#3471)

parent ea5f14e6
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
import os
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.lm_head = None
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> SamplerOutput:
with torch.inference_mode():
seq_ids = []
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.mistral.model import MistralForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers import MistralForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = MistralForCausalLM.from_pretrained(
model_name_or_path, low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = MistralForSampling.from_pretrained(
split_model_dir, **kwargs)
self.model.to_neuron()
"""Utilities for selecting and loading models."""
from typing import Type
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Optional, Type
import torch
import torch.nn as nn
import transformers
from transformers import PretrainedConfig
from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
......@@ -20,31 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = {
torch.float32: "f32",
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
"MistralForSampling", "MistralForCausalLM")
}
class NeuronCasualLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
logits = self.model(input_ids,
cache_ids=positions,
start_ids=input_block_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = neuronx_model_cls.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
from transformers_neuronx.config import (NeuronConfig,
ContinuousBatchingConfig)
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")
model_class = _get_model_architecture(model_config.hf_config)
linear_method = None
# Create a model instance.
model = model_class(model_config.hf_config, linear_method)
model = NeuronCasualLM(model_config.hf_config)
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
......@@ -54,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
tp_degree=parallel_config.neuron_tp_degree,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
......
......@@ -4,11 +4,11 @@ from typing import Dict, List, Optional, Tuple
import torch
import random
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl, is_neuron
from vllm.model_executor.layers.ops.sample import (
get_num_triton_sampler_splits)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import is_pin_memory_available
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
......@@ -213,7 +213,7 @@ class SamplingTensors:
dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
pin_memory = not in_wsl() and not is_neuron()
pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
......
"""Utils for model executor."""
import random
import importlib
from typing import Any, Dict, Optional
import numpy as np
import torch
from vllm.config import DeviceConfig, ModelConfig
DEVICE_TO_MODEL_LOADER_MAP = {
"cuda": "model_loader",
"neuron": "neuron_model_loader",
}
def set_random_seed(seed: int) -> None:
random.seed(seed)
......@@ -41,12 +33,3 @@ def set_weight_attrs(
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> torch.nn.Module:
model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
imported_model_loader = importlib.import_module(
f"vllm.model_executor.{model_loader_module}")
get_model_fn = imported_model_loader.get_model
return get_model_fn(model_config, device_config, **kwargs)
......@@ -2,7 +2,7 @@ import torch
from dataclasses import dataclass
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from typing import Optional
from vllm.utils import in_wsl
from vllm.utils import is_pin_memory_available
import time
from typing import Callable
......@@ -63,7 +63,7 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = not in_wsl()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
......
......@@ -27,8 +27,8 @@ class MultiStepWorker(Worker):
self._proposer: Optional[DraftModelTop1Proposer] = None
def init_model(self):
super().init_model()
def init_device(self):
super().init_device()
self._proposer = DraftModelTop1Proposer(
self,
......
......@@ -79,13 +79,13 @@ class SpecDecodeWorker:
self.scorer: SpeculativeScorer = None
def init_model(self) -> None:
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self.scorer_worker.init_model()
self.proposer_worker.init_model()
self.scorer_worker.init_device()
self.proposer_worker.init_device()
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
......
......@@ -338,7 +338,27 @@ def create_kv_caches_with_random(
return key_caches, value_caches
class measure_cuda_memory:
@lru_cache
def print_warning_once(msg: str) -> None:
logger.warning(msg)
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
if in_wsl():
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
return False
elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
return True
class CudaMemoryProfiler:
def __init__(self, device=None):
self.device = device
......@@ -360,3 +380,44 @@ class measure_cuda_memory:
# Force garbage collection
gc.collect()
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
def async_tensor_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
"""Expand the tensor to the target_dims."""
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
......@@ -5,7 +5,7 @@ import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
......@@ -38,10 +38,6 @@ class CacheEngine:
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
# Skip initializing KV cache for Neuron backend.
if is_neuron():
return
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
......@@ -90,12 +86,7 @@ class CacheEngine:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
pin_memory = not in_wsl()
if not pin_memory:
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
logger.warning("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
pin_memory = is_pin_memory_available()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape),
......
import contextlib
import time
from typing import Dict, List, Optional, Tuple, Set, Union
from typing import Dict, List, Optional, Tuple, Set
import numpy as np
import torch
......@@ -9,7 +9,8 @@ import torch.nn as nn
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
......@@ -21,7 +22,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl, measure_cuda_memory
from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
logger = init_logger(__name__)
......@@ -79,16 +82,11 @@ class ModelRunner:
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result
self.in_wsl = in_wsl()
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
# Set enforce_eager to True for Neuron backend, to avoid capturing graph
if self.device_config.is_neuron:
self.model_config.enforce_eager = True
def load_model(self) -> None:
with measure_cuda_memory() as m:
with CudaMemoryProfiler() as m:
self.model = get_model(self.model_config,
self.device_config,
lora_config=self.lora_config,
......@@ -238,7 +236,7 @@ class ModelRunner:
device=self.device)
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
block_tables = make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
......@@ -395,7 +393,7 @@ class ModelRunner:
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = _make_tensor_with_pad(
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
......@@ -436,7 +434,6 @@ class ModelRunner:
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -469,7 +466,7 @@ class ModelRunner:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device="cuda").manual_seed(sampling_params.seed)
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
......@@ -494,17 +491,17 @@ class ModelRunner:
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: _maybe_expand_dim(
_async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=pin_memory), 2, 2)
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
......@@ -910,27 +907,6 @@ def _maybe_cupy_nccl():
yield
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def _make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
......@@ -944,21 +920,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _async_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def _maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.neuron_model_loader import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class NeuronModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.pin_memory = is_pin_memory_available()
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
context_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)
hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
input_block_ids=input_block_ids,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
return output
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
"""A Neuron worker class."""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.neuron_model_runner import NeuronModelRunner
class Worker:
class NeuronWorker:
"""A worker class that executes the model on a group of neuron cores.
"""
......@@ -26,168 +21,32 @@ class Worker:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config)
def init_model(self) -> None:
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
distributed_backend="gloo")
# Initialize the model.
def init_device(self) -> None:
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int = 128,
gpu_memory_utilization: float = 0.9,
cpu_swap_space: int = 0,
cache_dtype: str = "float16",
) -> Tuple[int, int]:
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as
num_cpu_blocks."""
num_gpu_blocks = self.scheduler_config.max_num_seqs
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
# Warm up is maintained in transformers-neuronx
pass
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
if cache_events is not None:
raise NotImplementedError(
"cache operations are not implemented for neuron backend.")
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
output = self.model_runner.execute_model(seq_group_metadata_list)
return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
distributed_backend: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
distributed_backend = (distributed_backend
if distributed_backend else "nccl")
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1))
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
......@@ -67,7 +67,7 @@ class Worker:
self.cache_engine = None
self.gpu_cache = None
def init_model(self, cupy_port: Optional[int] = None) -> None:
def init_device(self, cupy_port: Optional[int] = None) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
......@@ -91,7 +91,7 @@ class Worker:
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
cupy_port, self.distributed_init_method)
# Initialize the model.
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment