Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
......@@ -9,4 +9,4 @@ except Exception as e:
stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.3.post1"
__version__ = "0.5.4"
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import nn
......@@ -12,7 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -41,7 +40,8 @@ class CPUModelInput(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -135,7 +135,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]:
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
......@@ -204,8 +204,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=None,
max_decode_seq_len=None,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
......@@ -213,8 +213,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
slot_mapping=slot_mapping,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
......@@ -336,7 +335,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
......@@ -345,7 +345,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
multi_modal_kwargs=multi_modal_kwargs,
)
@torch.inference_mode()
@torch.no_grad()
def execute_model(
self,
model_input: CPUModelInput,
......@@ -359,11 +359,16 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
model_executable = self.model
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)
......
......@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
......@@ -13,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, init_kmp_env
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
......@@ -152,13 +153,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
# try to initialize intel openmp optimized tunings
init_kmp_env()
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
if omp_cpuids == "all":
self.local_omp_cpuid = "all"
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
self.model_runner: CPUModelRunner = CPUModelRunner(
model_config,
parallel_config,
......@@ -177,6 +183,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
......
......@@ -8,10 +8,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
ModelInputForGPUBuilder)
logger = init_logger(__name__)
......@@ -28,6 +30,7 @@ class EmbeddingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def __init__(
self,
......@@ -97,11 +100,16 @@ class EmbeddingModelRunner(
kv_caches = [None] * num_layers
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)
......
......@@ -3,9 +3,9 @@ import gc
import time
import warnings
import weakref
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
......@@ -23,6 +23,7 @@ except ImportError:
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
......@@ -40,7 +41,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -94,7 +95,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
......@@ -171,48 +172,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@dataclass
# Note: ideally we would be using a dataclass(kw_only=True)
# here, so that this can be subclassed easily,
# but kw_only is not supported in python<3.10.
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str
seq_ids: List[int]
is_prompt: bool
block_tables: Optional[Dict[int, List[int]]]
computed_block_nums: List[int]
n_seqs: int = 0
# Input tokens and positions.
input_tokens: List[List[int]] = field(default_factory=list)
input_positions: List[List[int]] = field(default_factory=list)
# The sequence length (may be capped to the sliding window).
seq_lens: List[int] = field(default_factory=list)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list)
# The query length.
query_lens: List[int] = field(default_factory=list)
# The number of tokens that are already computed.
context_lens: List[int] = field(default_factory=list)
# The current sliding window block.
curr_sliding_window_blocks: List[int] = field(default_factory=list)
# LoRA inputs.
lora_index_mapping: List[List[int]] = field(default_factory=list)
lora_prompt_mapping: List[List[int]] = field(default_factory=list)
lora_requests: Set[LoRARequest] = field(default_factory=set)
# Prompt adapter inputs.
prompt_adapter_index_mapping: List[int] = field(default_factory=list)
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False
def __init__(
self,
*,
# From sequence group metadata.
request_id: str,
seq_ids: List[int],
is_prompt: bool,
block_tables: Optional[Dict[int, List[int]]],
computed_block_nums: List[int],
n_seqs: int = 0,
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None,
# The sequence length (may be capped to the sliding window).
seq_lens: Optional[List[int]] = None,
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None,
# The query length.
query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
context_lens: Optional[List[int]] = None,
# The current sliding window block.
curr_sliding_window_blocks: Optional[List[int]] = None,
# LoRA inputs.
lora_index_mapping: Optional[List[List[int]]] = None,
lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None,
# Prompt adapter inputs.
prompt_adapter_index_mapping: Optional[List[int]] = None,
prompt_adapter_prompt_mapping: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
):
self.request_id = request_id
self.seq_ids = seq_ids
self.is_prompt = is_prompt
self.block_tables = block_tables
self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
or [])
self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit
self.__post_init__()
def __post_init__(self):
self.n_seqs = len(self.seq_ids)
......@@ -457,6 +493,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)
def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
......@@ -491,10 +533,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
}
batch_size = len(input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
use_captured_graph = self._use_captured_graph(batch_size,
max_decode_seq_len)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
......@@ -539,9 +579,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
**dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping,
is_prefill=not self.decode_only))
# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
......@@ -569,8 +609,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None
]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.runner.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
......@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder]
def __init__(
self,
......@@ -747,6 +787,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
self.model = torch.compile(self.model,
fullgraph=True,
backend="eager")
def save_sharded_state(
self,
path: str,
......@@ -794,8 +839,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
builder = ModelInputForGPUBuilder(weakref.proxy(self),
finished_requests_ids)
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
......@@ -1040,9 +1084,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:
batch_size]
_indptr_buffer = indptr_buffer[:batch_size + 1]
_last_page_len_buffer = last_page_len_buffer[:
batch_size]
num_qo_heads = (
self.model_config.get_num_attention_heads(
......@@ -1055,8 +1099,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
use_tensor_cores = False
decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer,
indices_buffer, last_page_len_buffer, "NHD",
decode_workspace_buffer, _indptr_buffer,
indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype)
......@@ -1114,9 +1158,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
**dict(index_mapping=[0] * batch_size,
prompt_mapping=[0] * batch_size,
is_prefill=False))
self.set_active_loras(set(), lora_mapping)
if self.prompt_adapter_config:
......@@ -1131,10 +1175,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer
graph_runner.flashinfer_indptr_buffer = _indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \
last_page_len_buffer
_last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \
......@@ -1191,6 +1235,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
......@@ -1224,11 +1269,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
......@@ -1317,7 +1366,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
# Compute the logits in the last pipeline stage.
......
......@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
......@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
Execute the model on the given input.
"""
raise NotImplementedError
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)
return self.generators
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -10,7 +9,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -32,7 +31,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -84,8 +83,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
str, BatchedTensors]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
......@@ -134,8 +133,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
......@@ -219,7 +217,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
......@@ -243,7 +242,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**(model_input.multi_modal_kwargs or {}),
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
# Compute the logits.
......
from typing import List, Mapping, NamedTuple, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
......@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
......@@ -25,7 +25,7 @@ class ModelInput(NamedTuple):
attn_metadata: Optional[OpenVINOAttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
multi_modal_kwargs: Mapping[str, BatchedTensors]
multi_modal_kwargs: BatchedTensorInputs
@classmethod
def empty(cls, device):
......@@ -265,8 +265,7 @@ class OpenVINOModelRunner:
max_context_len=max_context_len_tensor,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return ModelInput(
input_tokens,
......@@ -281,7 +280,7 @@ class OpenVINOModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
SamplingMetadata, Mapping[str, BatchedTensors]]:
SamplingMetadata, BatchedTensorInputs]:
# Prepare input tensors.
(
input_tokens,
......@@ -324,11 +323,16 @@ class OpenVINOModelRunner:
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
**(multi_modal_kwargs or {}),
"input_ids":
input_tokens,
"positions":
input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
attn_metadata,
**MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)
......
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
......@@ -26,7 +28,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_PAD_SLOT_ID = -1 # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
......@@ -45,6 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
num_samples: int
best_of: List[int]
seq_groups: List[List[int]]
virtual_engine: int = 0
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -55,6 +60,9 @@ class ModelInputForTPU(ModelRunnerInputBase):
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
"best_of": self.best_of,
"seq_groups": self.seq_groups,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
......@@ -113,21 +121,45 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def load_model(self) -> None:
self.device = self.device_config.device
model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model, backend="openxla", fullgraph=True)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Setting dynamic=True can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=True)
def _dummy_run(
self,
......@@ -384,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
best_of = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
if sampling_params.temperature >= 1e-5 else 1e-5)
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
......@@ -463,10 +492,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
tensor_dict, attn_backend=self.attn_backend)
return model_input
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
......@@ -647,13 +677,23 @@ class ModelWrapper(nn.Module):
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = logits / t.unsqueeze(dim=1)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
next_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
......
......@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401
import torch_xla.runtime as xr
import vllm.envs as envs
......@@ -70,13 +69,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
self.device = xm.xla_device()
self.device_config.device = self.device
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# NOTE(woosuk): This is just a hack to initialize the TP group.
# This cannot perform the actual communication ops.
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
......@@ -88,6 +87,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
......@@ -100,7 +104,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results.
xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False)
# NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid
# race conditions.
xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH,
readonly=not self.is_driver_worker)
def load_model(self):
self.model_runner.load_model()
......@@ -200,8 +207,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
@property
def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP.
return False
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
......
......@@ -186,7 +186,9 @@ class Worker(LocalOrDistributedWorkerBase):
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was "
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
......
......@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
......@@ -267,7 +267,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
......@@ -276,14 +277,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
# output is List[SamplerOutput]
return output
def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
......@@ -307,7 +311,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)
if self.kv_cache is not None else None, intermediate_tensors)
class WorkerWrapperBase:
......
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
......@@ -14,7 +13,7 @@ from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
......@@ -49,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
......@@ -246,7 +245,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
......@@ -375,11 +375,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
model_executable = self.model
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)
......@@ -403,7 +408,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]:
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
......@@ -495,8 +500,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
......@@ -138,7 +138,9 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was "
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment