Unverified Commit e90fc21f authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Hardware][Neuron] Refactor neuron support (#3471)

parent ea5f14e6
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
import os
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.lm_head = None
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> SamplerOutput:
with torch.inference_mode():
seq_ids = []
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.mistral.model import MistralForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers import MistralForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = MistralForCausalLM.from_pretrained(
model_name_or_path, low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = MistralForSampling.from_pretrained(
split_model_dir, **kwargs)
self.model.to_neuron()
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading neuron models."""
from typing import Type import importlib
import os
from typing import Optional, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
import transformers
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, DeviceConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP = { TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32", "auto": "f32",
...@@ -20,31 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = { ...@@ -20,31 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = {
torch.float32: "f32", torch.float32: "f32",
} }
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
"MistralForSampling", "MistralForCausalLM")
}
class NeuronCasualLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
logits = self.model(input_ids,
cache_ids=positions,
start_ids=input_block_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = neuronx_model_cls.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch) if arch in _NEURON_SUPPORTED_MODELS:
if model_cls is not None: return arch
return model_cls
raise ValueError( raise ValueError(
f"Model architectures {architectures} are not supported for now. " f"Model architectures {architectures} are not supported on Neuron "
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def get_model(model_config: ModelConfig, device_config: DeviceConfig, def get_neuron_model(model_config: ModelConfig,
**kwargs) -> nn.Module: parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
from transformers_neuronx.config import (NeuronConfig, from transformers_neuronx.config import (NeuronConfig,
ContinuousBatchingConfig) ContinuousBatchingConfig)
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")
model_class = _get_model_architecture(model_config.hf_config)
linear_method = None
# Create a model instance. # Create a model instance.
model = model_class(model_config.hf_config, linear_method) model = NeuronCasualLM(model_config.hf_config)
continuous_batching_config = ContinuousBatchingConfig( continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs) batch_size_for_shared_caches=scheduler_config.max_num_seqs)
...@@ -54,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, ...@@ -54,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
# Load the weights from the cached or downloaded files. # Load the weights from the cached or downloaded files.
model.load_weights( model.load_weights(
model_config.model, model_config.model,
model_config.download_dir, tp_degree=parallel_config.tensor_parallel_size,
model_config.load_format,
model_config.revision,
tp_degree=parallel_config.neuron_tp_degree,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config, neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len], context_length_estimate=[scheduler_config.max_model_len],
......
...@@ -4,11 +4,11 @@ from typing import Dict, List, Optional, Tuple ...@@ -4,11 +4,11 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import random import random
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl, is_neuron
from vllm.model_executor.layers.ops.sample import ( from vllm.model_executor.layers.ops.sample import (
get_num_triton_sampler_splits) get_num_triton_sampler_splits)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import is_pin_memory_available
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558 _SEED_0_REPLACEMENT = 3403598558
...@@ -213,7 +213,7 @@ class SamplingTensors: ...@@ -213,7 +213,7 @@ class SamplingTensors:
dtype: torch.dtype) -> "SamplingTensors": dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
pin_memory = not in_wsl() and not is_neuron() pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens) prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_padded_tokens = [ prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens)) tokens + [vocab_size] * (prompt_max_len - len(tokens))
......
"""Utils for model executor.""" """Utils for model executor."""
import random import random
import importlib
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import numpy as np import numpy as np
import torch import torch
from vllm.config import DeviceConfig, ModelConfig
DEVICE_TO_MODEL_LOADER_MAP = {
"cuda": "model_loader",
"neuron": "neuron_model_loader",
}
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
random.seed(seed) random.seed(seed)
...@@ -41,12 +33,3 @@ def set_weight_attrs( ...@@ -41,12 +33,3 @@ def set_weight_attrs(
assert not hasattr( assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}") weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value) setattr(weight, key, value)
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> torch.nn.Module:
model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
imported_model_loader = importlib.import_module(
f"vllm.model_executor.{model_loader_module}")
get_model_fn = imported_model_loader.get_model
return get_model_fn(model_config, device_config, **kwargs)
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from dataclasses import dataclass from dataclasses import dataclass
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from typing import Optional from typing import Optional
from vllm.utils import in_wsl from vllm.utils import is_pin_memory_available
import time import time
from typing import Callable from typing import Callable
...@@ -63,7 +63,7 @@ class AsyncMetricsCollector: ...@@ -63,7 +63,7 @@ class AsyncMetricsCollector:
self._in_flight_copy: Optional[torch.cuda.Event] = None self._in_flight_copy: Optional[torch.cuda.Event] = None
pin_memory = not in_wsl() pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor( self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory) 0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor( self._aggregate_num_emitted_tokens = torch.tensor(
......
...@@ -27,8 +27,8 @@ class MultiStepWorker(Worker): ...@@ -27,8 +27,8 @@ class MultiStepWorker(Worker):
self._proposer: Optional[DraftModelTop1Proposer] = None self._proposer: Optional[DraftModelTop1Proposer] = None
def init_model(self): def init_device(self):
super().init_model() super().init_device()
self._proposer = DraftModelTop1Proposer( self._proposer = DraftModelTop1Proposer(
self, self,
......
...@@ -79,13 +79,13 @@ class SpecDecodeWorker: ...@@ -79,13 +79,13 @@ class SpecDecodeWorker:
self.scorer: SpeculativeScorer = None self.scorer: SpeculativeScorer = None
def init_model(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
""" """
# The scorer worker model is initialized first in case the proposer # The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker. # model has a smaller TP degree than the target worker.
self.scorer_worker.init_model() self.scorer_worker.init_device()
self.proposer_worker.init_model() self.proposer_worker.init_device()
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank)
......
...@@ -338,7 +338,27 @@ def create_kv_caches_with_random( ...@@ -338,7 +338,27 @@ def create_kv_caches_with_random(
return key_caches, value_caches return key_caches, value_caches
class measure_cuda_memory: @lru_cache
def print_warning_once(msg: str) -> None:
logger.warning(msg)
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
if in_wsl():
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
return False
elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
return True
class CudaMemoryProfiler:
def __init__(self, device=None): def __init__(self, device=None):
self.device = device self.device = device
...@@ -360,3 +380,44 @@ class measure_cuda_memory: ...@@ -360,3 +380,44 @@ class measure_cuda_memory:
# Force garbage collection # Force garbage collection
gc.collect() gc.collect()
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
def async_tensor_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
"""Expand the tensor to the target_dims."""
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -38,10 +38,6 @@ class CacheEngine: ...@@ -38,10 +38,6 @@ class CacheEngine:
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks
# Skip initializing KV cache for Neuron backend.
if is_neuron():
return
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype self.dtype = model_config.dtype
else: else:
...@@ -90,12 +86,7 @@ class CacheEngine: ...@@ -90,12 +86,7 @@ class CacheEngine:
cpu_cache: List[KVCache] = [] cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape() key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape() value_block_shape = self.get_value_block_shape()
pin_memory = not in_wsl() pin_memory = is_pin_memory_available()
if not pin_memory:
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
logger.warning("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
for _ in range(self.num_layers): for _ in range(self.num_layers):
key_blocks = torch.empty( key_blocks = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape), size=(self.num_cpu_blocks, *key_block_shape),
......
import contextlib import contextlib
import time import time
from typing import Dict, List, Optional, Tuple, Set, Union from typing import Dict, List, Optional, Tuple, Set
import numpy as np import numpy as np
import torch import torch
...@@ -9,7 +9,8 @@ import torch.nn as nn ...@@ -9,7 +9,8 @@ import torch.nn as nn
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
...@@ -21,7 +22,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata ...@@ -21,7 +22,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl, measure_cuda_memory from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -79,16 +82,11 @@ class ModelRunner: ...@@ -79,16 +82,11 @@ class ModelRunner:
# The shape of the cached block table will be # The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size). # (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling. self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result self.pin_memory = is_pin_memory_available()
self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
# Set enforce_eager to True for Neuron backend, to avoid capturing graph
if self.device_config.is_neuron:
self.model_config.enforce_eager = True
def load_model(self) -> None: def load_model(self) -> None:
with measure_cuda_memory() as m: with CudaMemoryProfiler() as m:
self.model = get_model(self.model_config, self.model = get_model(self.model_config,
self.device_config, self.device_config,
lora_config=self.lora_config, lora_config=self.lora_config,
...@@ -238,7 +236,7 @@ class ModelRunner: ...@@ -238,7 +236,7 @@ class ModelRunner:
device=self.device) device=self.device)
# Prepare prefix block tables # Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad( block_tables = make_tensor_with_pad(
prefix_block_tables, prefix_block_tables,
max_len=max_prompt_block_table_len, max_len=max_prompt_block_table_len,
pad=0, pad=0,
...@@ -395,7 +393,7 @@ class ModelRunner: ...@@ -395,7 +393,7 @@ class ModelRunner:
else: else:
max_block_table_len = max( max_block_table_len = max(
len(block_table) for block_table in block_tables) len(block_table) for block_table in block_tables)
block_tables = _make_tensor_with_pad( block_tables = make_tensor_with_pad(
block_tables, block_tables,
max_len=max_block_table_len, max_len=max_block_table_len,
pad=0, pad=0,
...@@ -436,7 +434,6 @@ class ModelRunner: ...@@ -436,7 +434,6 @@ class ModelRunner:
categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0 categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -469,7 +466,7 @@ class ModelRunner: ...@@ -469,7 +466,7 @@ class ModelRunner:
if sampling_params.seed is not None: if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator( seq_group_metadata.state.generator = torch.Generator(
device="cuda").manual_seed(sampling_params.seed) device=self.device).manual_seed(sampling_params.seed)
else: else:
num_seqs = len(seq_ids) num_seqs = len(seq_ids)
selected_token_indices.extend( selected_token_indices.extend(
...@@ -494,17 +491,17 @@ class ModelRunner: ...@@ -494,17 +491,17 @@ class ModelRunner:
if sampling_params.seed is not None: if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator) generators.append(seq_group_metadata.state.generator)
selected_token_indices = _async_h2d(selected_token_indices, selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long, dtype=torch.long,
target_device=self.device, target_device=self.device,
pin_memory=not self.in_wsl) pin_memory=self.pin_memory)
categorized_sample_indices = { categorized_sample_indices = {
t: _maybe_expand_dim( t: maybe_expand_dim(
_async_h2d(seq_ids, async_tensor_h2d(seq_ids,
dtype=torch.int, dtype=torch.int,
target_device=self.device, target_device=self.device,
pin_memory=pin_memory), 2, 2) pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items() for t, seq_ids in categorized_sample_indices.items()
} }
...@@ -910,27 +907,6 @@ def _maybe_cupy_nccl(): ...@@ -910,27 +907,6 @@ def _maybe_cupy_nccl():
yield yield
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def _make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int: def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size. """Returns the padded batch size given actual batch size.
...@@ -944,21 +920,3 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -944,21 +920,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else: else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _async_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def _maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.neuron_model_loader import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class NeuronModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.pin_memory = is_pin_memory_available()
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
context_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)
hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
input_block_ids=input_block_ids,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
return output
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
"""A Neuron worker class.""" """A Neuron worker class."""
from typing import Dict, List, Optional, Tuple from typing import List, Optional
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig) SchedulerConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.model_runner import ModelRunner
class Worker: class NeuronWorker:
"""A worker class that executes the model on a group of neuron cores. """A worker class that executes the model on a group of neuron cores.
""" """
...@@ -26,168 +21,32 @@ class Worker: ...@@ -26,168 +21,32 @@ class Worker:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, self.model_runner = NeuronModelRunner(model_config, parallel_config,
parallel_config, scheduler_config, device_config)
scheduler_config,
device_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
def init_model(self) -> None: def init_device(self) -> None:
# Initialize the distributed environment. # Set random seed.
_init_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
distributed_backend="gloo")
# Initialize the model.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def load_model(self): def load_model(self):
self.model_runner.load_model() self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int = 128,
gpu_memory_utilization: float = 0.9,
cpu_swap_space: int = 0,
cache_dtype: str = "float16",
) -> Tuple[int, int]:
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as
num_cpu_blocks."""
num_gpu_blocks = self.scheduler_config.max_num_seqs
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
# Warm up is maintained in transformers-neuronx
pass
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
if cache_events is not None:
raise NotImplementedError(
"cache operations are not implemented for neuron backend.")
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return {}
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list)
self.gpu_cache)
return output return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
distributed_backend: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
distributed_backend = (distributed_backend
if distributed_backend else "nccl")
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1))
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
...@@ -67,7 +67,7 @@ class Worker: ...@@ -67,7 +67,7 @@ class Worker:
self.cache_engine = None self.cache_engine = None
self.gpu_cache = None self.gpu_cache = None
def init_model(self, cupy_port: Optional[int] = None) -> None: def init_device(self, cupy_port: Optional[int] = None) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until # torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow # the synchronization point. This causes the memory usage to grow
...@@ -91,7 +91,7 @@ class Worker: ...@@ -91,7 +91,7 @@ class Worker:
# Initialize the distributed environment. # Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank, init_distributed_environment(self.parallel_config, self.rank,
cupy_port, self.distributed_init_method) cupy_port, self.distributed_init_method)
# Initialize the model. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def load_model(self): def load_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment