Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
...@@ -5,7 +5,18 @@ from typing import Optional ...@@ -5,7 +5,18 @@ from typing import Optional
import torch import torch
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch
# Default value for each supported sampling parameter. These values are used
# as the fill value for padded (unused) request slots when sampling tensors
# are copied to the device.
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "min_p": 0.0,
    # strictly disabled for now
    # "top_k": -1,
    # "top_p": 0.0,
    # "frequency_penalties": 0.0,
    # "presence_penalties": 0.0,
    # "repetition_penalties": 0.0,
}
@dataclass @dataclass
...@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata: ...@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k: torch.Tensor = None top_k: torch.Tensor = None
top_p: torch.Tensor = None top_p: torch.Tensor = None
# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag for compiling single xla graph. # Greedy sampling flag for compiling single xla graph.
do_argmax: torch.Tensor = None all_greedy: torch.Tensor = None
# speculation not supported
spec_token_ids = None
# Generator not supported by xla # Generator not supported by xla
generators: dict[int, generators: dict[int,
...@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata: ...@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids = None bad_words_token_ids = None
indices_do_sample: torch.Tensor = None indices_do_sample: torch.Tensor = None
def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)
@classmethod @classmethod
def from_sampling_metadata( def from_input_batch(
cls, metadata: SamplingMetadata, cls, input_batch: InputBatch,
padded_do_sample_indices: torch.Tensor, num_do_sample: int, indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
device: torch.device) -> "TPUSupportedSamplingMetadata":
""" """
Create an XLA-frienly SamplingMetadata structure. Do so by first Copy sampling tensors slices from `input_batch` to on device tensors.
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
recompilation is not triggered for optional values (None/torch.Tensor). slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
In order to handle different sizes for the params that range from 1 up also reuses the on-device persistent tensors managed in `input_batch`
to `max_num_seqs`, pad tensors to the closest pre-compiled shape. to reduce waste.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape. `indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0] We expect sampling params tensors to be padded to the same fixed shape.
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
""" """
metadata = cls._validate_sampling_metadata(metadata) num_reqs = input_batch.num_reqs
# NOTE we have to initialize default tensor-based params first and padded_num_reqs = len(indices_do_sample)
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices) def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
do_argmax = torch.tensor(metadata.all_greedy, fill_val) -> torch.Tensor:
dtype=torch.bool, # Copy slice from CPU to corresponding TPU pre-allocated tensor.
device=device) # Pad value is the default one.
new_metadata = cls.get_default_sampling_params(num_samples, device, cpu_tensor[num_reqs:padded_num_reqs] = fill_val
indices_do_sample=\ # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
padded_do_sample_indices, tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
do_argmax=do_argmax
) # NOTE NickLucche The sync CPU-TPU graph we produce here must be
supported_params = \ # consistent. We can't have flags to skip copies or we'll end up
TPUSupportedSamplingMetadata._get_default_params_values() # recompiling.
# Copy input non-None values into `new_metadata` fixed-sized tensors. copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
for p_name in supported_params: DEFAULT_SAMPLING_PARAMS["temperature"])
old_val = getattr(metadata, p_name) # TODO Temporarily disabled until sampling options are enabled
new_val = getattr(new_metadata, p_name) # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
if isinstance(old_val, torch.Tensor): # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
new_val[:num_do_sample] = old_val copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
setattr(new_metadata, p_name, new_val) DEFAULT_SAMPLING_PARAMS["min_p"])
xm.mark_step() xm.mark_step()
xm.wait_device_ops() xm.wait_device_ops()
return new_metadata
@classmethod # Slice persistent device tensors to a fixed pre-compiled padded shape.
def get_default_sampling_params( return cls(
cls, temperature=input_batch.temperature[:padded_num_reqs],
num_samples: int, # Scalar tensor for xla-friendly tracing.
device: torch.device, all_greedy=torch.tensor(input_batch.all_greedy,
indices_do_sample=None, dtype=torch.bool,
do_argmax=None) -> "TPUSupportedSamplingMetadata": device=input_batch.device),
# As sampling happens on a single traced graph, options # TODO enable more and avoid returning None values
# are "disabled" by having them evaluate to an Identity op. top_p=None, # input_batch.top_p[:padded_num_reqs],
# Note that initialization is dependent on num_samples. top_k=None, # input_batch.top_k[:padded_num_reqs],
sampling_metadata_disable_value = \ min_p=input_batch.min_p[:padded_num_reqs],
TPUSupportedSamplingMetadata._get_default_params_values() generators=input_batch.generators,
init_kwargs = dict() indices_do_sample=indices_do_sample)
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor
return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)
@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Set to None since #13587. Make sure default isn't overruled.
assert sampling_metadata.temperature is None
return sampling_metadata
@staticmethod
def _get_default_params_values():
return dict(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle import pickle
from types import FunctionType
from typing import Any, Optional from typing import Any, Optional
import cloudpickle
import torch import torch
from msgspec import msgpack from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1 CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2 CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
class MsgpackEncoder: class MsgpackEncoder:
...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any: ...@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any: ...@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return torch.from_numpy(pickle.loads(data)) return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE: if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data) return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported") raise NotImplementedError(f"Extension type code {code} is not supported")
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
class EagleProposer:
    """Proposes draft tokens for speculative decoding with a draft model.

    The first draft token is produced by running the draft model over all
    target tokens; the remaining `num_speculative_tokens - 1` tokens are
    produced one at a time, one token per request per step.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.block_size = vllm_config.cache_config.block_size
        # `arange` is sliced into `query_start_loc` for the one-token-per-
        # request steps in `propose`. A cumulative query-start tensor has
        # batch_size + 1 entries, so allocate max_num_seqs + 1 elements.
        # int32 matches what varlen attention kernels expect for cumulative
        # sequence lengths (torch.arange would default to int64).
        self.arange = torch.arange(
            vllm_config.scheduler_config.max_num_seqs + 1,
            device=device,
            dtype=torch.int32,
        )

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate `num_speculative_tokens` draft tokens per request.

        Returns:
            A tuple `(draft_token_ids, draft_probs)` of shapes
            [batch_size, num_speculative_tokens] and
            [batch_size, num_speculative_tokens, vocab_size].
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        # Index of the last token of every request in the flattened batch.
        last_token_indices = cu_num_tokens[1:] - 1

        input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        input_ids[last_token_indices] = next_token_ids

        seq_lens = target_positions[last_token_indices] + 1
        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
        max_seq_len = seq_lens.max().item()
        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_tokens,
            max_query_len=max_num_tokens,
            query_start_loc=cu_num_tokens,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=target_slot_mapping,
            # TODO(woosuk): Support cascade attention.
            use_cascade=False,
            common_prefix_len=0,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
        )

        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                hidden_states=target_hidden_states,
                positions=target_positions,
            )
        sample_hidden_states = hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
            logits, sampling_metadata)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1] and [batch_size, 1, vocab_size]
            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        draft_probs_list = [draft_probs]

        # Advanced indexing returns a new tensor, so the in-place `+= 1`
        # below does not modify `target_positions`.
        positions = target_positions[last_token_indices]
        hidden_states = sample_hidden_states
        # From here on every request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        # [0, 1, ..., batch_size]: batch_size + 1 cumulative start offsets.
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            input_ids = draft_token_ids_list[-1]
            positions += 1
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Compute the slot mapping.
            block_numbers = positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          positions % self.block_size)

            # Run the model.
            with set_forward_context(attn_metadata, self.vllm_config):
                hidden_states = self.model(
                    input_ids=input_ids,
                    hidden_states=hidden_states,
                    positions=positions,
                )
            logits = self.model.compute_logits(hidden_states, None)
            draft_token_ids, probs = compute_probs_and_sample_next_token(
                logits, sampling_metadata)
            draft_token_ids_list.append(draft_token_ids)
            draft_probs_list.append(probs)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        # [batch_size, num_speculative_tokens, vocab_size]
        draft_probs = torch.stack(draft_probs_list, dim=1)
        return draft_token_ids, draft_probs

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the kept-token layout after dropping rejected tokens.

        Returns:
            A tuple `(cu_num_tokens, token_indices)`: the cumulative number
            of kept tokens per request ([batch_size + 1], starting with 0)
            and, for every kept token, its index in the original flattened
            target batch.
        """
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        # Write the running sum directly into the tail; slot 0 is set below.
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        # One kernel program per request fills that request's index range.
        prepare_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Install the (dummy) draft model, borrowing the target model's
        input-embedding and logits functions."""
        self.model = DummyEagleModel()
        self.model.get_input_embeddings = target_model.get_input_embeddings
        self.model.compute_logits = target_model.compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):
    """Placeholder draft model: adds the token embeddings to the incoming
    hidden states. `get_input_embeddings` must be attached to the instance
    by the caller before `forward` is used."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # `positions` is accepted for interface parity but not used.
        return hidden_states + self.get_input_embeddings(input_ids)  # Dummy.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of `logits`.

    Only `all_greedy`, `all_random` and `temperature` are read from
    `sampling_metadata`; a temperature of -1 marks a greedy request.

    Note: on the non-greedy path `logits` is scaled by the temperature
    in place.

    Returns:
        A tuple `(next_token_ids, probs)`. `probs` is the softmax
        distribution, except when all requests are greedy, in which case
        the raw logits are returned instead.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    # Use a neutral temperature for greedy rows so the division is a no-op.
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # Do NOT divide in place: `probs` is still needed below for the greedy
    # argmax and is returned to the caller as the draft distribution.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs
# Triton kernel launched with one program per request. For request `pid` it
# writes the values index_start, index_start + 1, ... into
# out_ptr[start_pos : end_pos], where
#   start_pos / end_pos = cu_num_tokens_ptr[pid] / cu_num_tokens_ptr[pid + 1]
#   index_start         = cu_query_lens_ptr[pid]
@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos

    # First value to write for this request.
    index_start = tl.load(cu_query_lens_ptr + pid)

    # Process the request's range in chunks of BLOCK_SIZE elements.
    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            # Mask off lanes past the end of this request's range.
            mask=offset < num_tokens,
        )
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import numpy as np
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
    """Running counters of drafted and accepted speculative tokens."""

    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0

    def take(self):
        """Return a copy of the current counters and zero this instance."""
        snapshot = SpecDecodingStats(
            num_draft_tokens=self.num_draft_tokens,
            num_accepted_tokens=self.num_accepted_tokens,
        )
        self.reset()
        return snapshot

    def reset(self):
        """Zero both counters."""
        self.num_draft_tokens = self.num_accepted_tokens = 0

    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Add the given counts to the running totals."""
        self.num_draft_tokens = self.num_draft_tokens + num_draft_tokens
        self.num_accepted_tokens = (self.num_accepted_tokens +
                                    num_accepted_tokens)
class SpecDecodingMetrics:
    """Collects per-observation speculative decoding counts and logs the
    aggregate draft acceptance rate."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all collected observations."""
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Record the counters of one stats snapshot."""
        stats = spec_decoding_stats
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)

    def log(self):
        """Log the aggregated totals, then clear the collected data."""
        drafted = np.sum(self.num_draft_tokens)
        accepted = np.sum(self.num_accepted_tokens)
        # Rate is reported as a percentage; NaN when nothing was drafted.
        if drafted > 0:
            acceptance_pct = accepted / drafted * 100
        else:
            acceptance_pct = float("nan")

        logger.info(
            "SpecDecoding metrics: "
            "Draft acceptance rate: %.1f%%, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens",
            acceptance_pct,
            accepted,
            drafted,
        )
        self.reset()
...@@ -4,15 +4,27 @@ from typing import Optional ...@@ -4,15 +4,27 @@ from typing import Optional
import numpy as np import numpy as np
from numba import jit from numba import jit
from vllm.config import VllmConfig
class NgramProposer: class NgramProposer:
def __init__(self, vllm_config: VllmConfig):
    """Read the n-gram lookup settings from the speculative config and
    warm up the JIT-compiled matcher."""
    spec_config = vllm_config.speculative_config
    # Shortest n-gram length to try to match.
    self.min_n = spec_config.prompt_lookup_min
    # Longest n-gram length to try to match.
    self.max_n = spec_config.prompt_lookup_max
    # Number of tokens to return after the match. If fewer than k tokens
    # follow the match, everything up to the end is returned.
    self.k = spec_config.num_speculative_tokens

    # Trigger Numba JIT compilation for N-gram proposer.
    # This usually takes less than 1 second.
    warmup_context = np.zeros(1024, dtype=np.int32)
    self.propose(warmup_context)
def propose( def propose(
self, self,
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
min_n: int,
max_n: int,
k: int,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern """Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n matching in the context. The function finds matches of the last n
...@@ -22,17 +34,12 @@ class NgramProposer: ...@@ -22,17 +34,12 @@ class NgramProposer:
Args: Args:
context_token_ids: Numpy array of token IDs representing the context_token_ids: Numpy array of token IDs representing the
context sequence. context sequence.
min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return
the maximum amount of tokens until the end.
Returns: Returns:
np.ndarray: The sequence of tokens that followed np.ndarray: The sequence of tokens that followed
the matched n-gram in the context. the matched n-gram in the context.
None: If no matching n-gram pattern is found. None: If no matching n-gram pattern is found.
Example: Example:
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4: k = 4:
...@@ -44,12 +51,16 @@ class NgramProposer: ...@@ -44,12 +51,16 @@ class NgramProposer:
we only have three tokens after the match. we only have three tokens after the match.
""" """
# TODO(woosuk): Optimize this. # TODO(woosuk): Optimize this.
for n in range(max_n, min_n - 1, -1): for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, k) result = _find_subarray_kmp(context_token_ids, n, self.k)
if result is not None: if result is not None:
return result return result
return None return None
def load_model(self, *args, **kwargs):
    """Do nothing: there is no model to load. All arguments are ignored."""
    return None
@jit(nopython=True) @jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import multiprocessing import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -57,13 +57,13 @@ class StructuredOutputManager: ...@@ -57,13 +57,13 @@ class StructuredOutputManager:
raise ValueError( raise ValueError(
f"Unsupported structured output backend: {backend_name}") f"Unsupported structured output backend: {backend_name}")
grammar: Future[StructuredOutputGrammar] = self.executor.submit( grammar = self.executor.submit(self._async_create_grammar, request)
self._async_create_grammar, request, self.backend)
request.structured_output_request.grammar = grammar # type: ignore[assignment] request.structured_output_request.grammar = grammar # type: ignore[assignment]
def _async_create_grammar( def _async_create_grammar(
self, request: Request, self,
backend: StructuredOutputBackend) -> StructuredOutputGrammar: request: Request,
) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr] key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
# Note that the request was validated in the engine core client, # Note that the request was validated in the engine core client,
......
...@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend): ...@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
tokenizer_group.ping() tokenizer_group.ping()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.vocab_size = vllm_config.model_config.get_vocab_size() self.vocab_size = vllm_config.model_config.get_vocab_size()
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer = tokenizer_group.get_lora_tokenizer(None) tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None) self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
...@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend): ...@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
def compile_grammar(self, request_type: StructuredOutputOptions, def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar: grammar_spec: str) -> StructuredOutputGrammar:
self.serialized_grammar = serialize_guidance_grammar( self.serialized_grammar = serialize_guidance_grammar(
request_type, grammar_spec) request_type, grammar_spec, self.disable_any_whitespace)
ll_matcher = llguidance.LLMatcher( ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer, self.ll_tokenizer,
...@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar): ...@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
def serialize_guidance_grammar(request_type: StructuredOutputOptions, def serialize_guidance_grammar(request_type: StructuredOutputOptions,
grammar_spec: str) -> str: grammar_spec: str,
disable_any_whitespace: bool = False) -> str:
if request_type == StructuredOutputOptions.JSON: if request_type == StructuredOutputOptions.JSON:
# TODO: make whitespace_flexible configurable
return llguidance.LLMatcher.grammar_from_json_schema( return llguidance.LLMatcher.grammar_from_json_schema(
grammar_spec, defaults={ grammar_spec,
"whitespace_flexible": True, defaults={
"whitespace_flexible": not disable_any_whitespace,
}) })
elif request_type == StructuredOutputOptions.JSON_OBJECT: elif request_type == StructuredOutputOptions.JSON_OBJECT:
return llguidance.LLMatcher.grammar_from_json_schema( return llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}', defaults={ '{"type": "object"}',
"whitespace_flexible": True, defaults={
"whitespace_flexible": not disable_any_whitespace,
}) })
else: else:
if request_type == StructuredOutputOptions.REGEX: if request_type == StructuredOutputOptions.REGEX:
......
...@@ -42,12 +42,15 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -42,12 +42,15 @@ class XgrammarBackend(StructuredOutputBackend):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try: try:
encoded_vocab = [ if tokenizer.is_tekken:
token for token, _ in sorted( encoded_vocab = tokenizer._vocab
tokenizer.get_vocab().items(), else:
key=lambda x: x[1], encoded_vocab = [
) token for token, _ in sorted(
] tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None stop_token_ids = None
if hasattr( if hasattr(
tokenizer, tokenizer,
...@@ -62,7 +65,8 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -62,7 +65,8 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer_info = xgr.TokenizerInfo( # type: ignore tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab, encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK, vocab_type=xgr.VocabType.RAW
if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
add_prefix_space=True, add_prefix_space=True,
...@@ -80,7 +84,9 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -80,7 +84,9 @@ class XgrammarBackend(StructuredOutputBackend):
ctx = self.compiler.compile_json_schema( ctx = self.compiler.compile_json_schema(
grammar_spec, any_whitespace=not self.disable_any_whitespace) grammar_spec, any_whitespace=not self.disable_any_whitespace)
elif request_type == StructuredOutputOptions.JSON_OBJECT: elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar() ctx = self.compiler.compile_json_schema(
'{"type": "object"}',
any_whitespace=not self.disable_any_whitespace)
elif request_type == StructuredOutputOptions.GRAMMAR: elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec) ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX: elif request_type == StructuredOutputOptions.REGEX:
......
...@@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: ...@@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
if "pattern" in obj: if "pattern" in obj:
return True return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and any(
key in obj key in obj
......
...@@ -105,7 +105,7 @@ class BackgroundProcHandle: ...@@ -105,7 +105,7 @@ class BackgroundProcHandle:
process_kwargs: dict[Any, Any], process_kwargs: dict[Any, Any],
): ):
context = get_mp_context() context = get_mp_context()
reader, writer = context.Pipe(duplex=False) self.reader, writer = context.Pipe(duplex=False)
assert ("ready_pipe" not in process_kwargs assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs and "input_path" not in process_kwargs
...@@ -115,14 +115,17 @@ class BackgroundProcHandle: ...@@ -115,14 +115,17 @@ class BackgroundProcHandle:
process_kwargs["output_path"] = output_path process_kwargs["output_path"] = output_path
# Run busy loop in background process. # Run busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs) self.proc = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc, self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path) input_path, output_path)
self.proc.start() self.proc.start()
def wait_for_startup(self):
# Wait for startup. # Wait for startup.
if reader.recv()["status"] != "READY": if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. " raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.") "See root cause above.")
def shutdown(self): def shutdown(self):
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# Datastructures defining an input batch # Datastructures defining an input batch
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, cast from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
...@@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable ...@@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
@dataclass @dataclass
class CachedRequestState: class CachedRequestState:
...@@ -29,7 +26,7 @@ class CachedRequestState: ...@@ -29,7 +26,7 @@ class CachedRequestState:
prompt_token_ids: list[int] prompt_token_ids: list[int]
prompt: Optional[str] prompt: Optional[str]
mm_inputs: list[MultiModalKwargs] mm_inputs: list[MultiModalKwargs]
mm_positions: list["PlaceholderRange"] mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams sampling_params: SamplingParams
generator: Optional[torch.Generator] generator: Optional[torch.Generator]
...@@ -42,9 +39,18 @@ class CachedRequestState: ...@@ -42,9 +39,18 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids) return self.num_prompt_tokens + len(self.output_token_ids)
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch: class InputBatch:
......
...@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention ...@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -25,16 +24,18 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality ...@@ -25,16 +24,18 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, LazyLoader, cdiv, GiB_bytes, LayerBlockType, LazyLoader, cdiv,
is_pin_memory_available) check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheSpec) KVCacheConfig, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.spec_decode.utils import is_spec_decode_supported
...@@ -42,6 +43,8 @@ from vllm.v1.utils import bind_kv_cache ...@@ -42,6 +43,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
...@@ -70,6 +73,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -70,6 +73,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
model_config = self.model_config model_config = self.model_config
cache_config = self.cache_config cache_config = self.cache_config
scheduler_config = self.scheduler_config scheduler_config = self.scheduler_config
...@@ -106,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -106,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size() self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.head_size, self.head_size,
...@@ -130,13 +138,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -130,13 +138,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=self.mm_registry,
) )
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
...@@ -151,18 +159,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -151,18 +159,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.drafter = NgramProposer() if self.speculative_config.method == "ngram":
# Trigger Numba JIT compilation for N-gram proposer. self.drafter = NgramProposer(self.vllm_config)
# This usually takes less than 1 second. elif self.speculative_config.method == "eagle":
self.drafter.propose( self.drafter = EagleProposer(self.vllm_config,
np.zeros(1024, dtype=np.int32), self.device) # type: ignore
self.speculative_config.prompt_lookup_min, else:
self.speculative_config.prompt_lookup_max, raise ValueError("Unknown speculative decoding method: "
self.speculative_config.num_speculative_tokens, f"{self.speculative_config.method}")
)
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
# Request states. # Request states.
...@@ -223,6 +228,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -223,6 +228,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
# Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = check_use_alibi(model_config)
self.inputs_embeds = torch.zeros( self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
...@@ -671,7 +679,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -671,7 +679,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# use two kernels for cascade attention. Let's imagine: # use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D] # Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D] # Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
# If we use [A, B, C, D] as the common prefix for Request 1-3, # If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel, # then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not # and the second kernel will get an empty input. While this is not
...@@ -689,7 +697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -689,7 +697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_lens=num_scheduled_tokens, query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads, num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
use_alibi=False, # FIXME use_alibi=self.use_alibi,
use_sliding_window=self.window_size is not None, use_sliding_window=self.window_size is not None,
num_sms=self.num_sms, num_sms=self.num_sms,
) )
...@@ -861,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -861,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **batched_mm_inputs)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
)
for output in curr_group_outputs: for output in curr_group_outputs:
encoder_outputs.append(output) encoder_outputs.append(output)
...@@ -1085,8 +1098,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1085,8 +1098,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over # TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize. # the requests one by one. Optimize.
for i, generator in self.input_batch.generators.items(): discard_sampled_tokens_req_indices = []
req_id = self.input_batch.req_ids[i] for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
...@@ -1094,7 +1107,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1094,7 +1107,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Ignore the sampled token for partial prefills. # Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled. # Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4) generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here. # NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point. # Move as many CPU operations as possible before this sync point.
...@@ -1117,13 +1135,83 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1117,13 +1135,83 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output( valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size) sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.use_spec_decode: if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids( spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata) valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
draft_token_ids, draft_probs = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del draft_probs
return ModelRunnerOutput( return ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
...@@ -1159,11 +1247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1159,11 +1247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
end_idx = start_idx + num_sampled_ids end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose( drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx], self.input_batch.token_ids_cpu[i, :end_idx])
self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0: if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([]) draft_token_ids.append([])
else: else:
...@@ -1181,10 +1265,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1181,10 +1265,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.scheduler_config, self.scheduler_config,
self.lora_config, self.lora_config,
self.device) self.device)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load) time_after_load - time_before_load)
def _get_prompt_logprobs_dict( def _get_prompt_logprobs_dict(
...@@ -1425,9 +1512,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1425,9 +1512,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
# it supports multiple. # it supports multiple.
max_tokens_by_modality_dict = ( max_tokens_by_modality_dict = self.mm_registry \
MULTIMODAL_REGISTRY. .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
get_max_tokens_per_item_by_nonzero_modality(self.model_config))
dummy_data_modality, max_tokens_per_mm_item = max( dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1]) max_tokens_by_modality_dict.items(), key=lambda item: item[1])
...@@ -1459,24 +1545,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1459,24 +1545,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_budget, max_num_mm_items, dummy_data_modality) encoder_budget, max_num_mm_items, dummy_data_modality)
# Create dummy batch of multimodal inputs. # Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling( dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config, model_config=self.model_config,
seq_len=self.max_num_tokens, seq_len=self.max_num_tokens,
mm_registry=self.mm_registry, mm_counts={
) dummy_data_modality: 1
dummy_mm_data = dummy_request_data.multi_modal_data },
if not isinstance(dummy_mm_data, MultiModalKwargs): ).multi_modal_data
# TODO: Delete this check once input mapper is fully removed.
raise RuntimeError(
"Legacy input mapper is not supported in V1")
# Dummy data definition may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
dummy_mm_item = dummy_mm_data.get_item(
modality=dummy_data_modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items) [dummy_mm_kwargs] * max_num_mm_items)
...@@ -1486,12 +1561,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1486,12 +1561,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run multimodal encoder. # Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings( dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs) **batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number " sanity_check_mm_encoder_outputs(
f"of multimodal data items: {max_num_mm_items}, got " dummy_encoder_outputs,
f"{len(dummy_encoder_outputs)=} instead. This is most likely " expected_num_items=max_num_mm_items,
"due to the 'get_multimodal_embeddings' method of the model " )
"not implemented correctly.")
# Cache the dummy encoder outputs. # Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
...@@ -1562,7 +1636,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1562,7 +1636,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to # different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here. # the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape( kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
...@@ -1601,12 +1675,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1601,12 +1675,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# cross-attention # cross-attention
assert isinstance(attn_module, Attention) assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec( if attn_module.sliding_window is not None:
block_size=block_size, kv_cache_spec[layer_name] = SlidingWindowSpec(
num_kv_heads=attn_module.num_kv_heads, block_size=block_size,
head_size=attn_module.head_size, num_kv_heads=attn_module.num_kv_heads,
dtype=self.kv_cache_dtype, head_size=attn_module.head_size,
use_mla=use_mla) dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache. # encoder-only attention does not need KV cache.
......
...@@ -83,9 +83,9 @@ class Worker(WorkerBase): ...@@ -83,9 +83,9 @@ class Worker(WorkerBase):
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up() allocator.wake_up(tags)
def init_device(self): def init_device(self):
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
...@@ -269,6 +269,20 @@ class Worker(WorkerBase): ...@@ -269,6 +269,20 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running. # worker will always be healthy as long as it's running.
return return
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader.loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model_runner.model,
path,
pattern=pattern,
max_size=max_size,
)
def init_worker_distributed_environment( def init_worker_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import bisect
import time import time
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch from unittest.mock import patch
...@@ -16,7 +17,6 @@ from vllm.attention.backends.abstract import AttentionType ...@@ -16,7 +17,6 @@ from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
...@@ -24,12 +24,11 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality ...@@ -24,12 +24,11 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasAttentionBackend,
PallasMetadata) PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput, SamplerOutput) ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
...@@ -37,6 +36,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler ...@@ -37,6 +36,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -75,11 +76,15 @@ class TPUModelRunner: ...@@ -75,11 +76,15 @@ class TPUModelRunner:
parallel_config = self.parallel_config parallel_config = self.parallel_config
self.device = device self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
self._update_num_xla_graphs("init")
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
self._hidden_states_dtype = self.dtype
self.is_multimodal_model = model_config.is_multimodal_model self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
...@@ -87,7 +92,9 @@ class TPUModelRunner: ...@@ -87,7 +92,9 @@ class TPUModelRunner:
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs # InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
# Model-related. # Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type( self.num_attn_layers = model_config.get_num_layers_by_block_type(
...@@ -99,7 +106,6 @@ class TPUModelRunner: ...@@ -99,7 +106,6 @@ class TPUModelRunner:
self.hidden_size = model_config.get_hidden_size() self.hidden_size = model_config.get_hidden_size()
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL) # TODO: Support M-RoPE (e.g, Qwen2-VL)
...@@ -108,6 +114,7 @@ class TPUModelRunner: ...@@ -108,6 +114,7 @@ class TPUModelRunner:
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config, model_config=model_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
mm_registry=self.mm_registry,
) )
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size self.encoder_cache_size = encoder_cache_size
...@@ -147,11 +154,8 @@ class TPUModelRunner: ...@@ -147,11 +154,8 @@ class TPUModelRunner:
dtype=torch.int64, dtype=torch.int64,
device="cpu") device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.slot_mapping_np = self.slot_mapping_cpu.numpy()
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros( self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req), (self.max_num_tokens, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype, dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu") device="cpu")
...@@ -170,6 +174,35 @@ class TPUModelRunner: ...@@ -170,6 +174,35 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1]. # Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens # Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
total_cached_graphs = xr.get_num_cached_compilation_graph()
new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
if new_compiled_graphs == 0:
return
logger.info("Add new %d compiled XLA graphs due to %s",
new_compiled_graphs, case_str)
self.num_xla_graphs += new_compiled_graphs
def _verify_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected during {}."
" num_xla_graphs = {} curr_cached_graph = {}".format(
case_str, self.num_xla_graphs, curr_cached_graph))
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler """Update the cached states and the persistent batch with the scheduler
...@@ -279,9 +312,6 @@ class TPUModelRunner: ...@@ -279,9 +312,6 @@ class TPUModelRunner:
req_data.num_computed_tokens) req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids, self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index) req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
...@@ -300,9 +330,6 @@ class TPUModelRunner: ...@@ -300,9 +330,6 @@ class TPUModelRunner:
if removed_req_indices: if removed_req_indices:
self.input_batch.condense(removed_req_indices) self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
...@@ -322,17 +349,25 @@ class TPUModelRunner: ...@@ -322,17 +349,25 @@ class TPUModelRunner:
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items(): for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention) assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec( if attn_module.sliding_window is not None:
block_size=block_size, kv_cache_spec[layer_name] = SlidingWindowSpec(
num_kv_heads=attn_module.num_kv_heads, block_size=block_size,
head_size=attn_module.head_size, num_kv_heads=attn_module.num_kv_heads,
dtype=attn_module.dtype, head_size=attn_module.head_size,
use_mla=False, dtype=attn_module.dtype,
) sliding_window=attn_module.sliding_window,
use_mla=False,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=False,
)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache. # encoder-only attention does not need KV cache.
...@@ -428,7 +463,7 @@ class TPUModelRunner: ...@@ -428,7 +463,7 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU. # Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_token_len( padded_total_num_scheduled_tokens = _get_padded_token_len(
total_num_scheduled_tokens) self.num_tokens_paddings, total_num_scheduled_tokens)
# Zero out to avoid spurious values from prev iteration (last cp chunk) # Zero out to avoid spurious values from prev iteration (last cp chunk)
self.input_ids_cpu[ self.input_ids_cpu[
total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
...@@ -511,6 +546,11 @@ class TPUModelRunner: ...@@ -511,6 +546,11 @@ class TPUModelRunner:
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **batched_mm_inputs)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
)
for output in curr_group_outputs: for output in curr_group_outputs:
encoder_outputs.append(output) encoder_outputs.append(output)
...@@ -579,7 +619,6 @@ class TPUModelRunner: ...@@ -579,7 +619,6 @@ class TPUModelRunner:
# Prepare inputs # Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
if self.is_multimodal_model: if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
...@@ -597,14 +636,12 @@ class TPUModelRunner: ...@@ -597,14 +636,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids input_ids = self.input_ids
inputs_embeds = None inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape # NOTE (NickLucche) here we sync with TPU: sampling params tensors
# mismatch in pre-processing, it will trigger a small recompilation # are copied to device in chunks of pre-compiled padded shape to
# of the code thus far. Forward graph remains untouched. # avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices, from_input_batch(self.input_batch, logits_indices)
num_reqs, self.device)
# Run the decoder # Run the decoder
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
...@@ -621,6 +658,7 @@ class TPUModelRunner: ...@@ -621,6 +658,7 @@ class TPUModelRunner:
# Update the cache state concurrently. Code above will not block until # Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes # we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None assert req_id is not None
req_state = self.requests[req_id] req_state = self.requests[req_id]
...@@ -636,6 +674,10 @@ class TPUModelRunner: ...@@ -636,6 +674,10 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4) generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
assert all( assert all(
req_id is not None for req_id in req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None" self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
...@@ -649,11 +691,19 @@ class TPUModelRunner: ...@@ -649,11 +691,19 @@ class TPUModelRunner:
if max_gen_len == 1: if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist() valid_sampled_token_ids = selected_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Append sampled tokens
for i, req_state, seq_len in request_seq_lens: for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0] token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id) req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1 self.input_batch.num_tokens[i] += 1
else: else:
valid_mask = selected_token_ids != INVALID_TOKEN_ID valid_mask = selected_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist() gen_lens = valid_mask.sum(dim=1).tolist()
...@@ -676,12 +726,11 @@ class TPUModelRunner: ...@@ -676,12 +726,11 @@ class TPUModelRunner:
logprobs=None, logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
) )
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up. # Check there are no new graphs compiled - all the graphs should be
if self.check_recompilation and not self.enforce_eager: # captured and compiled during warm up.
curr_cached_graph = xr.get_num_cached_compilation_graph() self._verify_num_xla_graphs("execute_model")
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
return model_runner_output return model_runner_output
def load_model(self) -> None: def load_model(self) -> None:
...@@ -761,10 +810,11 @@ class TPUModelRunner: ...@@ -761,10 +810,11 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0): with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(input_ids=input_ids, out = self.model(input_ids=input_ids,
positions=position_ids, positions=position_ids,
kv_caches=kv_caches, kv_caches=kv_caches,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def capture_model(self) -> None: def capture_model(self) -> None:
"""Compile the model.""" """Compile the model."""
...@@ -772,63 +822,54 @@ class TPUModelRunner: ...@@ -772,63 +822,54 @@ class TPUModelRunner:
logger.info("Compiling the model with different input shapes.") logger.info("Compiling the model with different input shapes.")
start = time.perf_counter() start = time.perf_counter()
num_tokens = 16 for num_tokens in self.num_tokens_paddings:
while True:
logger.info(" -- num_tokens: %d", num_tokens) logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens) self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step() xm.mark_step()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops() xm.wait_device_ops()
end = time.perf_counter() end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start) logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
logger.info("Compiling sampling with different input shapes.") logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter() start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size() hsize = self.model_config.get_hidden_size()
device = self.device device = self.device
# Compile sampling step for different model+sampler outputs in bucketed # Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine. # n_tokens x max_num_reqs. Graph is really small so this is fine.
while True: for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize), dummy_hidden = torch.randn((num_tokens, hsize),
device=device, device=device,
dtype=torch.bfloat16) dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True: while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta = self.input_batch.sampling_metadata
indices = torch.zeros( indices = torch.zeros(
num_reqs_to_sample, num_reqs_to_sample,
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\ sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices, from_input_batch(self.input_batch, indices)
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample) num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta) out = self.model.sample_from_hidden(dummy_hidden,
xm.mark_step() sampling_meta)
if num_reqs_to_sample >= self.max_num_reqs: out = out.cpu()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
break break
num_reqs_to_sample *= 2 # Make sure to compile the `max_num_reqs` upper-limit case
if num_tokens >= self.max_num_tokens: num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
break num_reqs_to_sample + 1, self.max_num_reqs)
num_tokens *= 2
xm.wait_device_ops() xm.wait_device_ops()
end = time.perf_counter() end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start) logger.info("Compilation finished in in %.2f [secs].", end - start)
# Record the number cached XLA graph after warming up, this will be self._update_num_xla_graphs("sampling")
# used for checking there is no additional graph compilation during
# runtime execution.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
""" """
...@@ -856,12 +897,11 @@ class TPUModelRunner: ...@@ -856,12 +897,11 @@ class TPUModelRunner:
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape, tpu_kv_cache = torch.zeros(kv_cache_shape,
dtype=dtype, dtype=dtype,
device=self.device) device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) kv_caches[layer_name] = tpu_kv_cache
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -888,7 +928,7 @@ class ModelWrapperV1(nn.Module): ...@@ -888,7 +928,7 @@ class ModelWrapperV1(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: list[tuple[torch.Tensor, torch.Tensor]], kv_caches: list[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model. """Executes the forward pass of the model.
...@@ -923,10 +963,9 @@ class ModelWrapperV1(nn.Module): ...@@ -923,10 +963,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states = \ sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample] hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states) logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph on Sampler. # Optimized greedy sampling branch, tracing both paths in a single pass
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way. # NOTE all_greedy is a scalar, this is just an optimized if/else.
# NOTE do_argmax is a scalar, this is just an optimized if/else. out_tokens = torch.where(sampling_metadata.all_greedy,
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True), torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\ self.sample(logits, sampling_metadata)\
.sampled_token_ids) .sampled_token_ids)
...@@ -949,12 +988,50 @@ def _get_padded_number(n: int, multiple: int) -> int: ...@@ -949,12 +988,50 @@ def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple return ((n + multiple - 1) // multiple) * multiple
def _get_padded_token_len(x: int) -> int:
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int: def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit) return min(res, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
                  padding_gap: int) -> list[int]:
    """Generate an increasing list of padding sizes (token buckets).

    The list starts at ``min_token_size`` and always ends with a value
    that is greater than or equal to ``max_token_size``, so every request
    of up to ``max_token_size`` tokens fits into some bucket.

    If ``padding_gap == 0``:
        sizes double each time (exponential).
    else:
        sizes first double while they do not exceed ``padding_gap``,
        then grow by ``padding_gap`` per step.

    Args:
        min_token_size: Smallest bucket; must be positive.
        max_token_size: Number of tokens the largest bucket must cover.
        padding_gap: 0 for exponential buckets, otherwise the linear step.

    Returns:
        The bucket sizes in increasing order.

    Raises:
        ValueError: If ``min_token_size`` is not positive or
            ``padding_gap`` is negative (either would loop forever).
    """
    if min_token_size <= 0:
        raise ValueError(
            f"min_token_size must be positive, got {min_token_size}")
    if padding_gap < 0:
        raise ValueError(
            f"padding_gap must be non-negative, got {padding_gap}")

    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential paddings:")
        # Keep doubling until a bucket covers max_token_size. Stopping at
        # the last value <= max_token_size would leave sizes between that
        # value and max_token_size without any bucket.
        while True:
            logger.info(" %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental paddings:")
        while num <= padding_gap:
            logger.info(" %d", num)
            paddings.append(num)
            num *= 2
        # Undo the last doubling so the linear phase continues from the
        # largest size handled by the doubling loop.
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info(" %d", num)
            paddings.append(num)

    return paddings
def _get_padded_token_len(paddings: list[int], x: int) -> int:
"""Return the first element in paddings list greater or equal to x.
"""
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]
...@@ -18,7 +18,7 @@ from vllm.logger import init_logger ...@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
...@@ -66,20 +66,30 @@ class TPUWorker: ...@@ -66,20 +66,30 @@ class TPUWorker:
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profiler = None self.profiler = None
self.profile_dir = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler # For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0. # server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s", logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir) self.profile_dir)
self.profiler = xp.start_server(9012)
if self.model_config.seed is None: if self.model_config.seed is None:
self.model_config.seed = 0 self.model_config.seed = 0
def init_device(self): def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype) torch.set_default_dtype(self.model_config.dtype)
...@@ -101,17 +111,24 @@ class TPUWorker: ...@@ -101,17 +111,24 @@ class TPUWorker:
# Increase the cache size limit, which is the maximum number of # Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled. # dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and # TODO (NickLucche) On gsm we compile 80+ graphs.
# 30-40 graphs for decode. 128 is an arbitrary safe number. # Re-evaluate limit, with MM we may get close to this limit.
torch._dynamo.config.cache_size_limit = 128 torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation. # Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks # NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs. # can have slightly different XLA graphs.
world_size = self.parallel_config.world_size world_size = self.parallel_config.world_size
rank = xr.global_ordinal() rank = xr.global_ordinal()
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
f"tp{world_size}_rank{rank}") # Consequently, changes in optimization flags, which affect compilation
xr.initialize_cache(per_rank_path, readonly=False) # results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device. # Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(self.vllm_config, self.device) self.model_runner = TPUModelRunner(self.vllm_config, self.device)
...@@ -120,17 +137,18 @@ class TPUWorker: ...@@ -120,17 +137,18 @@ class TPUWorker:
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec() kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items(): for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec): if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype dtype = layer_spec.dtype
# Use an empty tensor instead of `None`` to force Dynamo to pass # Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``. # it by reference, rather by specializing on the value ``None``.
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device) tpu_kv_cache = torch.tensor([],
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device) dtype=dtype,
device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) kv_caches[layer_name] = tpu_kv_cache
else: else:
raise NotImplementedError raise NotImplementedError(
f"Unsupported KV cache spec '{type(layer_spec)}'")
runner_kv_caches: list[torch.Tensor] = [] runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache( bind_kv_cache(
...@@ -150,7 +168,13 @@ class TPUWorker: ...@@ -150,7 +168,13 @@ class TPUWorker:
# intermediate activations. # intermediate activations.
m = xm.get_memory_info(self.device) m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"] total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations. current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling. # Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size * usable_memory_size = int(total_memory_size *
...@@ -168,9 +192,11 @@ class TPUWorker: ...@@ -168,9 +192,11 @@ class TPUWorker:
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
if self.rank < 1: if self.rank < 1:
if self.profiler is None: if self.profile_dir is None:
raise RuntimeError("Profiler is not enabled.") raise RuntimeError("Profiler is not enabled.")
if is_start: if is_start:
if self.profiler is None:
self.profiler = xp.start_server(9012)
xp.start_trace(self.profile_dir) xp.start_trace(self.profile_dir)
else: else:
xp.stop_trace() xp.stop_trace()
......
# SPDX-License-Identifier: Apache-2.0
import torch
def sanity_check_mm_encoder_outputs(
    mm_embeddings: object,
    expected_num_items: int,
) -> None:
    """
    Validate the value returned by
    :meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.

    Three conditions are asserted in order: the value is a list, a tuple
    or a ``torch.Tensor``; its length equals ``expected_num_items``; and
    every element obtained by iterating it has ``ndim == 2``. The first
    violated condition raises ``AssertionError`` with an explanatory
    message.
    """
    accepted_containers = (list, tuple, torch.Tensor)
    assert isinstance(mm_embeddings, accepted_containers), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert expected_num_items == len(mm_embeddings), (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    # Iterating a 3D tensor yields its 2D slices, so this single check
    # covers both accepted layouts.
    assert not any(e.ndim != 2 for e in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")
...@@ -28,4 +28,13 @@ def _prev_minor_version_was(version_str): ...@@ -28,4 +28,13 @@ def _prev_minor_version_was(version_str):
return True return True
# Note - this won't do the right thing when we release 1.0! # Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
    """For the purpose of testing, return a previous minor version number.

    The result is the string "<major>.<minor - 1>" built from the first
    two entries of ``__version_tuple__``.
    """
    minor = __version_tuple__[1]
    assert isinstance(minor, int)
    major = __version_tuple__[0]
    # In a dev tree this returns "0.-1", which still works fine.
    return f"{major}.{minor - 1}"
...@@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free, self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None ) if needs_attn_backend else None
# Multi-modal data support # Multi-modal data support
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A CPU worker class.""" """A CPU worker class."""
import os
from typing import Dict, List, Optional, Set, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type
import torch import torch
...@@ -67,6 +68,7 @@ class CPUCacheEngine: ...@@ -67,6 +68,7 @@ class CPUCacheEngine:
cache_config.cache_dtype, cache_config.cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free, self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) )
# Initialize the cache. # Initialize the cache.
...@@ -106,7 +108,7 @@ class CPUCacheEngine: ...@@ -106,7 +108,7 @@ class CPUCacheEngine:
num_layers = model_config.get_num_layers(parallel_config) num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_layers * (key_cache_block + value_cache_block) total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto": if cache_dtype == "auto":
dtype = model_config.dtype dtype = model_config.dtype
...@@ -139,6 +141,8 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -139,6 +141,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
...@@ -217,6 +221,10 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -217,6 +221,10 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret: if ret:
logger.info(ret) logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.init_distributed_environment() self.init_distributed_environment()
# Set random seed. # Set random seed.
......
...@@ -1145,8 +1145,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1145,8 +1145,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load) time_after_load - time_before_load)
if self.prompt_adapter_config: if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
...@@ -1244,6 +1244,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1244,6 +1244,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = self.scheduler_config.max_num_seqs max_num_seqs = self.scheduler_config.max_num_seqs
self._dummy_run(max_num_batched_tokens, max_num_seqs) self._dummy_run(max_num_batched_tokens, max_num_seqs)
def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
    """Register ``num_loras`` dummy LoRAs and return their requests.

    Requests get ids ``1..num_loras`` and names ``warmup_<id>``. Each one
    is added through ``self.lora_manager.add_dummy_lora`` with rank
    ``LORA_WARMUP_RANK`` while the manager's ``dummy_lora_cache()``
    context is active.

    Args:
        num_loras: Number of dummy LoRAs to create; must be positive.

    Returns:
        The created requests, ordered by id.
    """
    assert num_loras > 0
    assert self.lora_manager is not None

    requests: list[LoRARequest] = []
    with self.lora_manager.dummy_lora_cache():
        # LoRA ids start at 1.
        for lora_id in range(1, num_loras + 1):
            request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_path="/not/a/real/path",
            )
            self.lora_manager.add_dummy_lora(request, rank=LORA_WARMUP_RANK)
            requests.append(request)
    return requests
def _remove_dummy_loras(self):
# Remove dummy loras.
assert self.lora_manager is not None
self.remove_all_loras()
def _dummy_run(self, def _dummy_run(self,
max_num_batched_tokens: int, max_num_batched_tokens: int,
max_num_seqs: int = 1) -> None: max_num_seqs: int = 1) -> None:
...@@ -1253,28 +1276,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1253,28 +1276,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
# This represents the maximum number of different requests # This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory # that will have unique loras, and therefore the max amount of
# consumption create dummy lora request copies from the lora request # memory consumption. Create dummy lora request copies from the
# passed in, which contains a lora from the lora warmup path. # lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = [] dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = [] dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config: if self.lora_config:
assert self.lora_manager is not None dummy_lora_requests = self._add_dummy_loras(
with self.lora_manager.dummy_lora_cache(): self.lora_config.max_loras)
for idx in range(self.lora_config.max_loras): assert len(dummy_lora_requests) == self.lora_config.max_loras
lora_id = idx + 1 dummy_lora_requests_per_seq = [
dummy_lora_request = LoRARequest( dummy_lora_requests[idx % len(dummy_lora_requests)]
lora_name=f"warmup_{lora_id}", for idx in range(max_num_seqs)
lora_int_id=lora_id, ]
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the # Profile memory usage with max_num_sequences sequences and the
# total number of tokens equal to max_num_batched_tokens. # total number of tokens equal to max_num_batched_tokens.
...@@ -1356,9 +1371,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1356,9 +1371,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
if self.lora_config: if self.lora_config:
# Remove dummy loras. self._remove_dummy_loras()
assert self.lora_manager is not None
self.remove_all_loras()
return return
def remove_all_loras(self): def remove_all_loras(self):
...@@ -1481,6 +1495,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1481,6 +1495,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
dummy_lora_id: Optional[int] = None
dummy_lora_request: LoRARequest = []
if self.lora_config:
# The goal is to capture the LoRA kernels in cuda graphs.
# For this purpose, a single dummy lora is sufficient.
dummy_lora_requests = self._add_dummy_loras(num_loras=1)
assert len(dummy_lora_requests) == 1
dummy_lora_request = dummy_lora_requests[0]
dummy_lora_id = dummy_lora_request.lora_int_id
with self.attn_state.graph_capture(max_batch_size), graph_capture( with self.attn_state.graph_capture(max_batch_size), graph_capture(
self.device) as graph_capture_context: self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
...@@ -1505,10 +1529,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1505,10 +1529,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
attn_metadata.enable_kv_scales_calculation = False attn_metadata.enable_kv_scales_calculation = False
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size, **dict(index_mapping=[dummy_lora_id] * batch_size,
prompt_mapping=[0] * batch_size, prompt_mapping=[dummy_lora_id] * batch_size,
is_prefill=False)) is_prefill=False))
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set([dummy_lora_request]),
lora_mapping)
if self.prompt_adapter_config: if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping( prompt_adapter_mapping = PromptAdapterMapping(
...@@ -1564,6 +1589,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1564,6 +1589,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][batch_size] = (
graph_runner) graph_runner)
if self.lora_config:
self._remove_dummy_loras()
end_time = time.perf_counter() end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment