Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
from contextlib import contextmanager
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
......@@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
......@@ -68,7 +68,7 @@ def create_sequence_group_output(
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
logprobs: Dict[Optional[int], Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
......
......@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
RWConfig)
from vllm.transformers_utils.configs import (ChameleonConfig, ChatGLMConfig,
DbrxConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, RWConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
......@@ -18,6 +18,7 @@ else:
logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chameleon": ChameleonConfig,
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
......
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
......@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
__all__ = [
"ChameleonConfig",
"ChameleonVQVAEConfig",
"ChatGLMConfig",
"DbrxConfig",
"MPTConfig",
......
from typing import List, Optional
from transformers import PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
model_type = "chameleon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65536,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
model_parallel_size=1,
swin_norm=False,
vq_config=None,
vocabulary_map=None,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_bias = mlp_bias
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_parallel_size = model_parallel_size
self.swin_norm = swin_norm
if vq_config is None:
vq_config = {}
self.vq_config = ChameleonVQVAEConfig(**vq_config)
self.vocabulary_map = vocabulary_map
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, "
f"`type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in [
"linear", "dynamic"
]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', "
f"'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, "
f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
model_type = "chameleon_vqgan"
def __init__(
self,
embed_dim: int = 256,
num_embeddings: int = 8192,
double_latent: bool = False,
latent_channels: int = 256,
resolution: int = 512,
in_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
num_res_blocks: int = 2,
attn_resolutions: Optional[List[int]] = None,
dropout: float = 0.0,
attn_type: str = "vanilla",
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.resolution = resolution
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.dropout = dropout
self.attn_type = attn_type
self.initializer_range = initializer_range
......@@ -165,6 +165,12 @@ class Detokenizer:
return len(new_decoded_token_text)
def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
......@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
return new_tokens, prefix_offset, read_offset
......
......@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
......@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_local_path, e)
"(Exception: %s)", lora_request.lora_path, e)
tokenizer = None
return tokenizer
......
from typing import Optional
from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
......@@ -16,18 +16,22 @@ else:
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
if tokenizer_pool_config is None:
return TokenizerGroup(**init_kwargs)
if tokenizer_pool_config.pool_type == "ray":
tokenizer_cls = TokenizerGroup
elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
tokenizer_pool_config.pool_type, BaseTokenizerGroup):
tokenizer_cls = tokenizer_pool_config.pool_type
elif tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None:
raise ImportError(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
**init_kwargs)
tokenizer_cls = RayTokenizerGroupPool
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
......@@ -3,12 +3,19 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
@classmethod
@abstractmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "BaseTokenizerGroup":
pass
@abstractmethod
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
......
......@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls = TokenizerGroup
@classmethod
def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "RayTokenizerGroupPool":
if not tokenizer_pool_config:
raise ValueError("tokenizer_pool_config must not be None.")
ray_actor_options = (tokenizer_pool_config.extra_config or {
"num_cpus": 0
})
......
......@@ -2,6 +2,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
......@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "TokenizerGroup":
return cls(**init_kwargs)
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
......
......@@ -16,12 +16,12 @@ import requests
import torch
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
"vllm/do_not_track")
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "do_not_track")
_USAGE_STATS_ENABLED = None
_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER
......@@ -205,7 +205,8 @@ class UsageMessage:
def _send_to_server(self, data):
try:
requests.post(_USAGE_STATS_SERVER, json=data)
global_http_client = global_http_connection.get_sync_client()
global_http_client.post(_USAGE_STATS_SERVER, json=data)
except requests.exceptions.RequestException:
# silently ignore unless we are using debug log
logging.debug("Failed to send usage data to server")
......
......@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union)
import numpy as np
import numpy.typing as npt
import psutil
import torch
import torch.types
......@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
# "fp8_e5m2": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int32: np.int32,
torch.int64: np.int64,
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
......@@ -415,9 +425,10 @@ def init_kmp_env():
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]:
def chunk_list(lst: List[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
for i in range(0, len(lst), chunk_size):
yield lst[i:i + chunk_size]
def cdiv(a: int, b: int) -> int:
......@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
def make_ndarray_with_pad(
x: List[List[T]],
pad: T,
dtype: npt.DTypeLike,
*,
max_len: Optional[int] = None,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
if max_len is None:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len
padded_x[ind, :len(blocktb)] = blocktb
return torch.tensor(padded_x, dtype=dtype, device=device)
return padded_x
def make_tensor_with_pad(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def async_tensor_h2d(
......@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
......@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
processed_args.append(arg)
return super().parse_args(processed_args, namespace)
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
**kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
......@@ -9,4 +9,4 @@ except Exception as e:
stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.2"
__version__ = "0.5.3.post1"
......@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
dtype=torch.int,
device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
......
......@@ -2,7 +2,8 @@ import dataclasses
import gc
import time
import warnings
from collections import defaultdict
import weakref
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
......@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
......@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import (
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
......@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
"""
This base class contains metadata needed for the base model forward pass
......@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return cls(**tensor_dict)
@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"""
Used by the ModelRunner.
......@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return cls(**tensor_dict)
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@dataclass
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str
seq_ids: List[int]
is_prompt: bool
block_tables: Optional[Dict[int, List[int]]]
computed_block_nums: List[int]
n_seqs: int = 0
# Input tokens and positions.
input_tokens: List[List[int]] = field(default_factory=list)
input_positions: List[List[int]] = field(default_factory=list)
# The sequence length (may be capped to the sliding window).
seq_lens: List[int] = field(default_factory=list)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list)
# The query length.
query_lens: List[int] = field(default_factory=list)
# The number of tokens that are already computed.
context_lens: List[int] = field(default_factory=list)
# The current sliding window block.
curr_sliding_window_blocks: List[int] = field(default_factory=list)
# LoRA inputs.
lora_index_mapping: List[List[int]] = field(default_factory=list)
lora_prompt_mapping: List[List[int]] = field(default_factory=list)
lora_requests: Set[LoRARequest] = field(default_factory=set)
# Prompt adapter inputs.
prompt_adapter_index_mapping: List[int] = field(default_factory=list)
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False
def __post_init__(self):
self.n_seqs = len(self.seq_ids)
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]
def __init__(self,
runner: "GPUModelRunnerBase",
finished_requests_ids: Optional[List[str]] = None):
super().__init__()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self.per_seq_compute_fns = [
self._compute_lens,
self._compute_for_prefix_cache_hit,
self._compute_for_sliding_window,
self._compute_lora_input,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [
self._compute_prompt_adapter_input,
self._compute_multi_modal_input,
]
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.scheduler_config = self.runner.scheduler_config
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
# Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
weakref.proxy(self))
# Engine/Model configurations.
self.chunked_prefill_enabled = (
self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled)
if self.sliding_window is not None:
self.sliding_window_blocks = (
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.token_chunk_size
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx] = tokens
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Check if hit prefix cache (i.e., some blocks are already computed).
If hit, update input tokens and positions to only compute the
remaining blocks.
"""
computed_block_nums = inter_data.computed_block_nums
# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")
# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block = 0
sliding_seq_len = inter_data.seq_lens[seq_idx]
if not inter_data.is_prompt and self.sliding_window is not None:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
curr_sliding_window_block = self.sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
sliding_seq_len = min(
inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_block += 1
else:
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.sliding_window)
inter_data.curr_sliding_window_blocks[
seq_idx] = curr_sliding_window_block
inter_data.seq_lens[seq_idx] = sliding_seq_len
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if not self.enable_lora:
return
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if not self.enable_prompt_adapter:
return
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id <= 0 or not inter_data.is_prompt:
return
# We expect only one sequence in the group when is_prompt=True.
assert inter_data.n_seqs == 1
query_len = inter_data.query_lens[0]
inter_data.prompt_adapter_request = (
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
inter_data.prompt_adapter_index_mapping = [
prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1)
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(mm_data)
inter_data.multi_modal_inputs = mm_kwargs
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = list(seq_group_metadata.seq_data.keys())
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
if is_prompt:
assert n_seqs == 1
self.decode_only = False
inter_data = self.InterDataForSeqGroup(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums)
self.inter_data_list.append(inter_data)
for seq_idx in range(n_seqs):
for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)
def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = flatten_2d_lists([
flatten_2d_lists(inter_data.input_tokens)
for inter_data in self.inter_data_list
])
if not input_tokens:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
input_positions = flatten_2d_lists([
flatten_2d_lists(inter_data.input_positions)
for inter_data in self.inter_data_list
])
seq_lens = []
max_decode_seq_len = 0
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
query_lens = flatten_2d_lists(
[inter_data.query_lens for inter_data in self.inter_data_list])
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
request_ids_to_seq_ids = {
data.request_id: data.seq_ids
for data in self.inter_data_list
}
batch_size = len(input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
# Tokens and positions.
input_tokens.extend([0] * cuda_graph_pad_size)
input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)
# Sequence and query lengths.
seq_lens.extend([1] * cuda_graph_pad_size)
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(r for data in self.inter_data_list
for r in data.lora_requests)
lora_index_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
lora_index_mapping.extend([0] * cuda_graph_pad_size)
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
prompt_adapter_mapping = None
if self.enable_prompt_adapter:
prompt_adapter_requests = set(
data.prompt_adapter_request for data in self.inter_data_list
if data.prompt_adapter_request is not None)
prompt_adapter_index_mapping = flatten_2d_lists([
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
# Multi-modal data.
multi_modal_inputs_list = [
data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None
]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.runner.device)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
Helper class for shared methods between GPU model runners.
......@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
......@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
decode_seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
block_tables: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = 0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return self._model_input_cls()
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window + self.block_size -
1) // self.block_size
block_aligned_sliding_window = \
sliding_window_blocks * self.block_size
builder = ModelInputForGPUBuilder(weakref.proxy(self),
finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
is_prompt = seq_group_metadata.is_prompt
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_data.get_len() - 1
seq_len = min(
seq_data.get_len(),
context_len + seq_group_metadata.token_chunk_size)
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks = None
sliding_seq_len = seq_len
sliding_context_len = context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if (self.sliding_window is not None and not is_prompt):
curr_sliding_window_blocks = sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = seq_len % self.block_size
sliding_seq_len = min(
seq_len, block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_blocks += 1
else:
sliding_seq_len = min(seq_len, self.sliding_window)
sliding_context_len = sliding_seq_len - 1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert self.sliding_window is None, \
"Prefix caching is not supported with sliding window"
sliding_context_len = context_len
if self.attn_backend.get_name() == "flash-attn":
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if curr_sliding_window_blocks is not None:
block_table = block_table[
-curr_sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []
else:
# Prefill without chunked prefill or memory profiling.
block_table = []
block_tables.append(block_table)
seq_lens.append(sliding_seq_len)
context_lens.append(sliding_context_len)
query_len = sliding_seq_len - sliding_context_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt:
assert len(seq_ids) == 1
num_prefills += 1
num_prefill_tokens += len(tokens)
decode_only = False
prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
num_decode_tokens += query_len
decode_seq_lens.append(sliding_seq_len)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * query_len
lora_prompt_mapping.extend(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
if is_prompt:
assert self.scheduler_config.use_v2_block_manager \
or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Prepare input tensors for flashinfer
if self.attn_backend.get_name() == "flashinfer":
seq_len = seq_data.get_len()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
paged_kv_indices.extend(block_table[:block_table_bound])
paged_kv_indptr.append(paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
batch_size = len(input_tokens)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph = (
decode_only and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
if use_captured_graph:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
logits_soft_cap = getattr(self.model_config.hf_config,
'attn_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
raise ValueError("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
device='cpu',
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
device='cpu',
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, device='cpu', dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
for seq_group_metadata in seq_group_metadata_list
}
return self._model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
@torch.inference_mode()
def profile_run(self) -> None:
......@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
......@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
......@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC):
raise NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
......@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@torch.inference_mode()
@current_platform.inference_mode()
def execute_model(
self,
model_input: T,
......
......@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_seq_len,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_seq_len,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
......@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
......
import time
from typing import List, Mapping, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
......@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
......@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES = 128
class TPUModelRunner:
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
best_of: List[int]
seq_groups: List[List[int]]
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
......@@ -68,10 +110,6 @@ class TPUModelRunner:
False,
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
def load_model(self) -> None:
self.device = self.device_config.device
......@@ -85,6 +123,7 @@ class TPUModelRunner:
multimodal_config=self.multimodal_config,
lora_config=None,
)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
......@@ -153,8 +192,8 @@ class TPUModelRunner:
# Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, None, t, p, num_samples)
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
def warmup_model(
self,
......@@ -183,7 +222,7 @@ class TPUModelRunner:
# Decode
start = time.time()
seq_len = 1
batch_size = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
xm.wait_device_ops()
......@@ -199,14 +238,12 @@ class TPUModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
Mapping[str, BatchedTensors]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
slot_mapping: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
slot_mapping: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -220,78 +257,62 @@ class TPUModelRunner:
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_tokens.extend(prompt_tokens)
input_positions.extend(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping.append([])
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
slot_mapping.append(slot)
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
num_prefill_tokens = sum(prompt_lens)
# Add paddings to make the shape [batch_size, max_prompt_len] where
# max_prompt_len is smallest power of 2 that is greater than or equal
# to the maximum prompt length.
# We need the 2D input shape because the Pallas FlashAttention kernel
# does not support packed 1D inputs.
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
slot_mapping = make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.int64,
device=self.device)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device=self.device)
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used.
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
block_tables=None,
context_lens=None,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
multi_modal_kwargs)
return input_tokens, input_positions, attn_metadata, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
Mapping[str, BatchedTensors]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
......@@ -317,11 +338,6 @@ class TPUModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
......@@ -331,22 +347,22 @@ class TPUModelRunner:
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device=self.device)
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device=self.device)
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device=self.device)
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device=self.device)
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device=self.device)
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
......@@ -355,12 +371,7 @@ class TPUModelRunner:
block_tables=block_tables,
context_lens=context_lens,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
return (input_tokens, input_positions, attn_metadata, input_lens,
multi_modal_kwargs)
return input_tokens, input_positions, attn_metadata, input_lens
def _prepare_sample(
self,
......@@ -412,16 +423,18 @@ class TPUModelRunner:
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device=self.device)
p = torch.tensor(p, dtype=torch.float32, device=self.device)
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, best_of
def _execute_model(
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> List[CompletionSequenceGroupOutput]:
# Prepare inputs.
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
......@@ -430,16 +443,104 @@ class TPUModelRunner:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:], t, p, num_samples)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, best_of,
seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
"""Move input args from CPU to device and execute the model."""
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
if clone:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x = x.clone()
return x.to(self.device)
new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg = _copy_to_device(arg)
elif isinstance(arg, AttentionMetadata):
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
if getattr(arg, "block_tables", None) is not None:
arg.block_tables = _copy_to_device(arg.block_tables)
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = _copy_to_device(arg.context_lens)
new_args.append(arg)
return self.model(*new_args)
num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
if is_prompt:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids = []
orig_slot_mapping = model_input.attn_metadata.slot_mapping
batch_size = model_input.input_lens.shape[0]
start_idx = 0
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
model_input.attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx]
model_input.attn_metadata.num_prefills = 1
output_token_ids = _execute_model(
model_input.token_ids[None, start_idx:end_idx],
model_input.position_ids[None, start_idx:end_idx],
model_input.attn_metadata,
model_input.input_lens[i:i + 1],
model_input.t[i:i + 1],
model_input.p[i:i + 1],
model_input.num_samples,
kv_caches,
clone=True)
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx
else:
# Execute the model.
output_token_ids = _execute_model(
model_input.token_ids, model_input.position_ids,
model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples,
kv_caches)
# Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
......@@ -447,13 +548,13 @@ class TPUModelRunner:
zero_logprob = Logprob(0.0)
batch_idx = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
for seq_group in model_input.seq_groups:
seq_ids = seq_group
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
if is_prompt:
assert len(seq_ids) == 1
seq_id = seq_ids[0]
for i in range(best_of[batch_idx]):
for i in range(model_input.best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
......@@ -468,35 +569,6 @@ class TPUModelRunner:
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
num_steps: int = 1,
) -> List[SamplerOutput]:
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
assert seq_group_metadata_list is not None
assert len(seq_group_metadata_list) > 0
if seq_group_metadata_list[0].is_prompt:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
sampler_outputs += self._execute_model([seq_group_metadata],
kv_caches)
else:
sampler_outputs = self._execute_model(seq_group_metadata_list,
kv_caches)
return [SamplerOutput(sampler_outputs)]
......@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model.eval()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor,
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
base_indicies = torch.arange(
start_indicies = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = base_indicies + input_lens - 1
logits_indices = start_indicies + input_lens - 1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
......@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module):
position_ids,
kv_caches,
attn_metadata,
**(multi_modal_kwargs or {}),
)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
......@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int:
def _get_padded_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
elif batch_size <= 8:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
......
......@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase):
class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
......@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
multimodal_config,
is_driver_worker=is_driver_worker)
self.model_runner: TPUModelRunner = TPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
......@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results.
xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH),
readonly=False)
xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False)
def load_model(self):
self.model_runner.load_model()
......@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
def execute_model(
@property
def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP.
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
def prepare_worker_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if not self.is_driver_worker:
self._execute_model_non_driver()
return []
assert execute_model_req is not None
# Issue cache operations.
self.cache_swap(
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy,
execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
blocks_to_swap_in = _make_src_to_dst(
execute_model_req.blocks_to_swap_in, "cpu", self.device)
blocks_to_swap_out = _make_src_to_dst(
execute_model_req.blocks_to_swap_out, self.device, "cpu")
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
self.device, self.device)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache)
return output
def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
assert virtual_engine == 0
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
if blocks_to_swap_in:
# Swap from CPU to TPU.
src_indices, dst_indices = _make_src_to_dst(
blocks_to_swap_in, "cpu", self.device)
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if blocks_to_swap_out:
# Swap from TPU to CPU.
src_indices, dst_indices = _make_src_to_dst(
blocks_to_swap_out, self.device, "cpu")
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
self.device)
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
# Issue cache operations.
if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_in
if src_indices.numel() > 0:
# Swap from CPU to TPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_out
if src_indices.numel() > 0:
# Swap from TPU to CPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if worker_input.blocks_to_copy is not None:
src_indices, dst_indices = worker_input.blocks_to_copy
if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
(src_indices, dst_indices))
def _make_src_to_dst(
......
......@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
# initialize_cache.
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.tensor]]] = None
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment