Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -9,4 +9,4 @@ except Exception as e: ...@@ -9,4 +9,4 @@ except Exception as e:
stacklevel=2) stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER" __commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.3.post1" __version__ = "0.5.4"
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
Type, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -12,7 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -12,7 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -41,7 +40,8 @@ class CPUModelInput(ModelRunnerInputBase): ...@@ -41,7 +40,8 @@ class CPUModelInput(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -135,7 +135,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -135,7 +135,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]: BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -204,8 +204,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -204,8 +204,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=None, seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=None, max_decode_seq_len=0,
num_prefills=len(seq_lens), num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens, num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0, num_decode_tokens=0,
...@@ -213,8 +213,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -213,8 +213,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
device=self.device)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -336,7 +335,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -336,7 +335,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# just use seq_lens instead. # just use seq_lens instead.
seq_lens, seq_lens,
self.device, self.device,
pin_memory=False) pin_memory=False,
generators=self.get_generators(finished_requests_ids))
return CPUModelInput( return CPUModelInput(
input_tokens=input_tokens, input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
...@@ -345,7 +345,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -345,7 +345,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
) )
@torch.inference_mode() @torch.no_grad()
def execute_model( def execute_model(
self, self,
model_input: CPUModelInput, model_input: CPUModelInput,
...@@ -359,11 +359,16 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -359,11 +359,16 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": model_input.input_tokens, "input_ids":
"positions": model_input.input_positions, model_input.input_tokens,
"kv_caches": kv_caches, "positions":
"attn_metadata": model_input.attn_metadata, model_input.input_positions,
**(model_input.multi_modal_kwargs or {}), "kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
} }
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
......
...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple ...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.distributed import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
...@@ -13,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -13,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, init_kmp_env from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput) LoraNotSupportedWorkerBase, WorkerInput)
...@@ -152,13 +153,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -152,13 +153,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
# try to initialize intel openmp optimized tunings
init_kmp_env()
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
if omp_cpuids == "all":
self.local_omp_cpuid = "all"
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
self.model_runner: CPUModelRunner = CPUModelRunner( self.model_runner: CPUModelRunner = CPUModelRunner(
model_config, model_config,
parallel_config, parallel_config,
...@@ -177,6 +183,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -177,6 +183,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cpu_cache: List[List[torch.Tensor]] self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None: def init_device(self) -> None:
if self.local_omp_cpuid != "all":
torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
self.init_distributed_environment() self.init_distributed_environment()
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
......
...@@ -8,10 +8,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -8,10 +8,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
ModelInputForGPUBuilder)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,6 +30,7 @@ class EmbeddingModelRunner( ...@@ -28,6 +30,7 @@ class EmbeddingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata) ModelInputForGPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def __init__( def __init__(
self, self,
...@@ -97,11 +100,16 @@ class EmbeddingModelRunner( ...@@ -97,11 +100,16 @@ class EmbeddingModelRunner(
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": model_input.input_tokens, "input_ids":
"positions": model_input.input_positions, model_input.input_tokens,
"kv_caches": kv_caches, "positions":
"attn_metadata": model_input.attn_metadata, model_input.input_positions,
**(model_input.multi_modal_kwargs or {}), "kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
} }
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
......
...@@ -3,9 +3,9 @@ import gc ...@@ -3,9 +3,9 @@ import gc
import time import time
import warnings import warnings
import weakref import weakref
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
Tuple, Type, TypeVar, Union) TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -23,6 +23,7 @@ except ImportError: ...@@ -23,6 +23,7 @@ except ImportError:
BatchPrefillWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
...@@ -40,7 +41,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -40,7 +41,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision) supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -94,7 +95,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -94,7 +95,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0 virtual_engine: int = 0
...@@ -171,48 +172,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -171,48 +172,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata.""" """Build ModelInputForGPU from SequenceGroupMetadata."""
@dataclass # Note: ideally we would be using a dataclass(kw_only=True)
# here, so that this can be subclassed easily,
# but kw_only is not supported in python<3.10.
class InterDataForSeqGroup: class InterDataForSeqGroup:
"""Intermediate data for the current sequence group.""" """Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str def __init__(
seq_ids: List[int] self,
is_prompt: bool *,
block_tables: Optional[Dict[int, List[int]]] # From sequence group metadata.
computed_block_nums: List[int] request_id: str,
n_seqs: int = 0 seq_ids: List[int],
is_prompt: bool,
# Input tokens and positions. block_tables: Optional[Dict[int, List[int]]],
input_tokens: List[List[int]] = field(default_factory=list) computed_block_nums: List[int],
input_positions: List[List[int]] = field(default_factory=list) n_seqs: int = 0,
# The sequence length (may be capped to the sliding window). # Input tokens and positions.
seq_lens: List[int] = field(default_factory=list) input_tokens: Optional[List[List[int]]] = None,
# The original sequence length (before applying sliding window). input_positions: Optional[List[List[int]]] = None,
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list) # The sequence length (may be capped to the sliding window).
# The query length. seq_lens: Optional[List[int]] = None,
query_lens: List[int] = field(default_factory=list) # The original sequence length (before applying sliding window).
# The number of tokens that are already computed. # This is used to compute slot mapping.
context_lens: List[int] = field(default_factory=list) orig_seq_lens: Optional[List[int]] = None,
# The current sliding window block. # The query length.
curr_sliding_window_blocks: List[int] = field(default_factory=list) query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
# LoRA inputs. context_lens: Optional[List[int]] = None,
lora_index_mapping: List[List[int]] = field(default_factory=list) # The current sliding window block.
lora_prompt_mapping: List[List[int]] = field(default_factory=list) curr_sliding_window_blocks: Optional[List[int]] = None,
lora_requests: Set[LoRARequest] = field(default_factory=set)
# LoRA inputs.
# Prompt adapter inputs. lora_index_mapping: Optional[List[List[int]]] = None,
prompt_adapter_index_mapping: List[int] = field(default_factory=list) lora_prompt_mapping: Optional[List[List[int]]] = None,
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list) lora_requests: Optional[Set[LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Prompt adapter inputs.
# Multi-modal inputs. prompt_adapter_index_mapping: Optional[List[int]] = None,
multi_modal_inputs: Optional[MultiModalInputs] = None prompt_adapter_prompt_mapping: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False # Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
):
self.request_id = request_id
self.seq_ids = seq_ids
self.is_prompt = is_prompt
self.block_tables = block_tables
self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
or [])
self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit
self.__post_init__()
def __post_init__(self): def __post_init__(self):
self.n_seqs = len(self.seq_ids) self.n_seqs = len(self.seq_ids)
...@@ -457,6 +493,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -457,6 +493,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for per_seq_group_fn in self.per_seq_group_compute_fns: for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata) per_seq_group_fn(inter_data, seq_group_metadata)
def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
def build(self) -> ModelInputForGPU: def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and """Finalize the builder intermediate data and
create on-device tensors. create on-device tensors.
...@@ -491,10 +533,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -491,10 +533,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
} }
batch_size = len(input_tokens) batch_size = len(input_tokens)
use_captured_graph = ( use_captured_graph = self._use_captured_graph(batch_size,
self.decode_only and not self.runner.model_config.enforce_eager max_decode_seq_len)
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
# If cuda graph can be used, pad tensors accordingly. # If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details. # See `capture_model` API for more details.
...@@ -539,9 +579,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -539,9 +579,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for inter_data in self.inter_data_list for inter_data in self.inter_data_list
]) ])
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
lora_index_mapping, **dict(index_mapping=lora_index_mapping,
lora_prompt_mapping, prompt_mapping=lora_prompt_mapping,
) is_prefill=not self.decode_only))
# Prompt adapter data. # Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set() prompt_adapter_requests: Set[PromptAdapterRequest] = set()
...@@ -569,8 +609,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -569,8 +609,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
data.multi_modal_inputs for data in self.inter_data_list data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None if data.multi_modal_inputs is not None
] ]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
device=self.runner.device)
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
...@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
Helper class for shared methods between GPU model runners. Helper class for shared methods between GPU model runners.
""" """
_model_input_cls: Type[TModelInputForGPU] _model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder]
def __init__( def __init__(
self, self,
...@@ -747,6 +787,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -747,6 +787,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
self.model = torch.compile(self.model,
fullgraph=True,
backend="eager")
def save_sharded_state( def save_sharded_state(
self, self,
path: str, path: str,
...@@ -794,8 +839,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -794,8 +839,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
builder = ModelInputForGPUBuilder(weakref.proxy(self), builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata) builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore return builder.build() # type: ignore
...@@ -1040,9 +1084,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1040,9 +1084,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.parallel_config.pipeline_parallel_size): self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1] _indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[: _last_page_len_buffer = last_page_len_buffer[:
batch_size] batch_size]
num_qo_heads = ( num_qo_heads = (
self.model_config.get_num_attention_heads( self.model_config.get_num_attention_heads(
...@@ -1055,8 +1099,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1055,8 +1099,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
use_tensor_cores = False use_tensor_cores = False
decode_wrapper = \ decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer, decode_workspace_buffer, _indptr_buffer,
indices_buffer, last_page_len_buffer, "NHD", indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores) use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype( kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype) self.kv_cache_dtype, self.model_config.dtype)
...@@ -1114,9 +1158,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1114,9 +1158,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
[0] * batch_size, **dict(index_mapping=[0] * batch_size,
[0] * batch_size, prompt_mapping=[0] * batch_size,
) is_prefill=False))
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set(), lora_mapping)
if self.prompt_adapter_config: if self.prompt_adapter_config:
...@@ -1131,10 +1175,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1131,10 +1175,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model, self.attn_backend.get_name()) self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer graph_runner.flashinfer_indptr_buffer = _indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \ graph_runner.flashinfer_last_page_len_buffer = \
last_page_len_buffer _last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \ graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \ graph_runner.flashinfer_decode_wrapper = \
...@@ -1191,6 +1235,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1191,6 +1235,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
""" """
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
...@@ -1224,11 +1269,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1224,11 +1269,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
""" """
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids) seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, if get_pp_group().is_last_rank:
model_input.seq_lens, # Sampling metadata is only required for the final pp group
model_input.query_lens, generators = self.get_generators(finished_requests_ids)
self.device, sampling_metadata = SamplingMetadata.prepare(
self.pin_memory) seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None) if seq_group_metadata_list else None)
return dataclasses.replace(model_input, return dataclasses.replace(model_input,
...@@ -1317,7 +1366,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1317,7 +1366,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs) **seqlen_agnostic_kwargs)
# Compute the logits in the last pipeline stage. # Compute the logits in the last pipeline stage.
......
...@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass. ModelRunnerInputBase subclass.
""" """
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
@abstractmethod @abstractmethod
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
...@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
Execute the model on the given input. Execute the model on the given input.
""" """
raise NotImplementedError raise NotImplementedError
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)
return self.generators
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Union)
import torch import torch
from torch import nn from torch import nn
...@@ -10,7 +9,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -10,7 +9,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -32,7 +31,7 @@ class ModelInputForNeuron(ModelRunnerInputBase): ...@@ -32,7 +31,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -84,8 +83,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -84,8 +83,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
str, BatchedTensors]]: BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
...@@ -134,8 +133,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -134,8 +133,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
device=self.device)
return (input_tokens, input_positions, input_block_ids, seq_lens, return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -219,7 +217,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -219,7 +217,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# just use seq_lens instead. # just use seq_lens instead.
seq_lens, seq_lens,
self.device, self.device,
self.pin_memory) self.pin_memory,
generators=self.get_generators(finished_requests_ids))
return ModelInputForNeuron(input_tokens=input_tokens, return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
...@@ -243,7 +242,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -243,7 +242,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids, input_block_ids=model_input.input_block_ids,
**(model_input.multi_modal_kwargs or {}), **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
) )
# Compute the logits. # Compute the logits.
......
from typing import List, Mapping, NamedTuple, Optional, Tuple from typing import List, NamedTuple, Optional, Tuple
import openvino as ov import openvino as ov
import torch import torch
...@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
...@@ -25,7 +25,7 @@ class ModelInput(NamedTuple): ...@@ -25,7 +25,7 @@ class ModelInput(NamedTuple):
attn_metadata: Optional[OpenVINOAttentionMetadata] attn_metadata: Optional[OpenVINOAttentionMetadata]
seq_lens: List[int] seq_lens: List[int]
query_lens: List[int] query_lens: List[int]
multi_modal_kwargs: Mapping[str, BatchedTensors] multi_modal_kwargs: BatchedTensorInputs
@classmethod @classmethod
def empty(cls, device): def empty(cls, device):
...@@ -265,8 +265,7 @@ class OpenVINOModelRunner: ...@@ -265,8 +265,7 @@ class OpenVINOModelRunner:
max_context_len=max_context_len_tensor, max_context_len=max_context_len_tensor,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
device=self.device)
return ModelInput( return ModelInput(
input_tokens, input_tokens,
...@@ -281,7 +280,7 @@ class OpenVINOModelRunner: ...@@ -281,7 +280,7 @@ class OpenVINOModelRunner:
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
SamplingMetadata, Mapping[str, BatchedTensors]]: SamplingMetadata, BatchedTensorInputs]:
# Prepare input tensors. # Prepare input tensors.
( (
input_tokens, input_tokens,
...@@ -324,11 +323,16 @@ class OpenVINOModelRunner: ...@@ -324,11 +323,16 @@ class OpenVINOModelRunner:
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": input_tokens, "input_ids":
"positions": input_positions, input_tokens,
"kv_caches": kv_caches, "positions":
"attn_metadata": attn_metadata, input_positions,
**(multi_modal_kwargs or {}), "kv_caches":
kv_caches,
"attn_metadata":
attn_metadata,
**MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
device=self.device),
} }
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import patch
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
...@@ -26,7 +28,9 @@ if TYPE_CHECKING: ...@@ -26,7 +28,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_PAD_SLOT_ID = -1 # NOTE(woosuk): In PyTorch XLA, index -1 is ignored. # Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False _ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`. # FIXME(woosuk): A temporary hack to support `n > 1`.
...@@ -45,6 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase): ...@@ -45,6 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
num_samples: int num_samples: int
best_of: List[int] best_of: List[int]
seq_groups: List[List[int]] seq_groups: List[List[int]]
virtual_engine: int = 0
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -55,6 +60,9 @@ class ModelInputForTPU(ModelRunnerInputBase): ...@@ -55,6 +60,9 @@ class ModelInputForTPU(ModelRunnerInputBase):
"t": self.t, "t": self.t,
"p": self.p, "p": self.p,
"num_samples": self.num_samples, "num_samples": self.num_samples,
"best_of": self.best_of,
"seq_groups": self.seq_groups,
"virtual_engine": self.virtual_engine,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict return tensor_dict
...@@ -113,21 +121,45 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -113,21 +121,45 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def load_model(self) -> None: def load_model(self) -> None:
self.device = self.device_config.device self.device = self.device_config.device
model = get_model( # NOTE(woosuk): While the executor assigns the TP ranks to the worker
model_config=self.model_config, # process, the ranks can be different from the ranks internally assigned
load_config=self.load_config, # by the xm runtime. Therefore, there is a mismatch in the rank
device_config=self.device_config, # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
parallel_config=self.parallel_config, # This is not a problem in linear layers because all-reduce is
cache_config=self.cache_config, # rank-agnostic. However, it matters for all-gather as the ranks
scheduler_config=self.scheduler_config, # determine the order of concatenating the output tensors.
multimodal_config=self.multimodal_config, # As a workaround, we use the xm's rank assignment only when loading
lora_config=None, # the embedding weights.
) xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
model = model.eval() model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
model = ModelWrapper(model) model = ModelWrapper(model)
self.model = torch.compile(model, backend="openxla", fullgraph=True) # NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Setting dynamic=True can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=True)
def _dummy_run( def _dummy_run(
self, self,
...@@ -384,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -384,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
best_of = [] best_of = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params sampling_params = seq_group_metadata.sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very t.append(sampling_params.temperature)
# low temperature. This is not accurate.
t.append(sampling_params.temperature
if sampling_params.temperature >= 1e-5 else 1e-5)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P: if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError( raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend " "Top-p sampling is currently disabled for the TPU backend "
...@@ -463,10 +492,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -463,10 +492,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
tensor_dict, attn_backend=self.attn_backend) tensor_dict, attn_backend=self.attn_backend)
return model_input return model_input
@torch.no_grad()
def execute_model( def execute_model(
self, self,
model_input: ModelInputForTPU, model_input: ModelInputForTPU,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
...@@ -647,13 +677,23 @@ class ModelWrapper(nn.Module): ...@@ -647,13 +677,23 @@ class ModelWrapper(nn.Module):
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = logits / t.unsqueeze(dim=1) # Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P: if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1)) logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32) probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
next_token_ids = torch.multinomial(probs, sampled_token_ids = torch.multinomial(probs,
num_samples, num_samples,
replacement=True) replacement=True)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids return next_token_ids
......
...@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union ...@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401
import torch_xla.runtime as xr import torch_xla.runtime as xr
import vllm.envs as envs import vllm.envs as envs
...@@ -70,13 +69,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -70,13 +69,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def init_device(self) -> None: def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
self.device = xm.xla_device()
self.device_config.device = self.device
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype) torch.set_default_dtype(self.model_config.dtype)
# NOTE(woosuk): This is just a hack to initialize the TP group. # NOTE(woosuk): This is just to initialize the TP group and broadcast
# This cannot perform the actual communication ops. # the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment( init_distributed_environment(
world_size=self.parallel_config.world_size, world_size=self.parallel_config.world_size,
rank=self.rank, rank=self.rank,
...@@ -88,6 +87,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -88,6 +87,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size) self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device) xm.set_rng_state(self.model_config.seed, self.device)
...@@ -100,7 +104,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -100,7 +104,10 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
# Use persistent cache to avoid XLA recompilation. # Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation # NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results. # overhead because dynamo does not cache the compiled results.
xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False) # NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid
# race conditions.
xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH,
readonly=not self.is_driver_worker)
def load_model(self): def load_model(self):
self.model_runner.load_model() self.model_runner.load_model()
...@@ -200,8 +207,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -200,8 +207,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
@property @property
def do_metadata_broadcast(self) -> bool: def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP. return self.parallel_config.tensor_parallel_size > 1
return False
@property @property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
......
...@@ -186,7 +186,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -186,7 +186,9 @@ class Worker(LocalOrDistributedWorkerBase):
# GPU did not change their memory usage during the profiling. # GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, ( assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was " "Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.") "not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes() cache_block_size = self.get_cache_block_size_bytes()
......
...@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -267,7 +267,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -267,7 +267,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank: if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors( intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict()) get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine] model_input, self.kv_cache[worker_input.virtual_engine]
...@@ -276,14 +277,17 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -276,14 +277,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
# output is IntermediateTensors # output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors) get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None] return [None]
# output is List[SamplerOutput] # output is List[SamplerOutput]
return output return output
def _execute_model_spmd( def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
""" """
Execute model in Single Program Multiple Data (SPMD) fashion. Execute model in Single Program Multiple Data (SPMD) fashion.
...@@ -307,7 +311,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -307,7 +311,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return self.model_runner.execute_model( return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine] model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None) if self.kv_cache is not None else None, intermediate_tensors)
class WorkerWrapperBase: class WorkerWrapperBase:
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
Type, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -14,7 +13,7 @@ from vllm.inputs import INPUT_REGISTRY ...@@ -14,7 +13,7 @@ from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
...@@ -49,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase): ...@@ -49,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -246,7 +245,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -246,7 +245,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# just use seq_lens instead. # just use seq_lens instead.
seq_lens, seq_lens,
self.device, self.device,
pin_memory=False) pin_memory=False,
generators=self.get_generators(finished_requests_ids))
# Broadcast the metadata. # Broadcast the metadata.
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
...@@ -375,11 +375,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -375,11 +375,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": model_input.input_tokens, "input_ids":
"positions": model_input.input_positions, model_input.input_tokens,
"kv_caches": kv_caches, "positions":
"attn_metadata": model_input.attn_metadata, model_input.input_positions,
**(model_input.multi_modal_kwargs or {}), "kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
} }
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
...@@ -403,7 +408,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -403,7 +408,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]: BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -495,8 +500,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -495,8 +500,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int), block_tables=torch.tensor([], device=self.device, dtype=torch.int),
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
device=self.device)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -138,7 +138,9 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -138,7 +138,9 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
# GPU did not change their memory usage during the profiling. # GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, ( assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was " "Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.") "not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes() cache_block_size = self.get_cache_block_size_bytes()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment