Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -120,7 +120,7 @@ class Processor: ...@@ -120,7 +120,7 @@ class Processor:
if not params.guided_decoding or not self.decoding_config: if not params.guided_decoding or not self.decoding_config:
return return
supported_backends = ["xgrammar"] supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
engine_level_backend = self.decoding_config.guided_decoding_backend engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends: if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is " raise ValueError(f"Only {supported_backends} structured output is "
...@@ -173,7 +173,6 @@ class Processor: ...@@ -173,7 +173,6 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists. # 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash, return_mm_hashes=self.use_hash,
......
...@@ -62,14 +62,11 @@ class Executor(ExecutorBase): ...@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
args=(kv_cache_configs, )) args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model") self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory") output = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum return output
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(output)
def get_kv_cache_specs(self) -> list[KVCacheSpec]: def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec") output = self.collective_rpc("get_kv_cache_spec")
return output return output
...@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor): ...@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
def determine_available_memory(self) -> int: # in bytes def determine_available_memory(self) -> list[int]: # in bytes
# same as determine_num_available_blocks in v0, # same as determine_num_available_blocks in v0,
# we need to get the min across all ranks. # we need to get the min across all ranks.
memory = super().determine_available_memory() memory = super().determine_available_memory()
...@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): ...@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
cpu_group = get_world_group().cpu_group cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item() return [memory_tensor.item()]
...@@ -5,6 +5,7 @@ import pickle ...@@ -5,6 +5,7 @@ import pickle
import signal import signal
import sys import sys
import time import time
import traceback
import weakref import weakref
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
...@@ -370,6 +371,9 @@ class WorkerProc: ...@@ -370,6 +371,9 @@ class WorkerProc:
func = partial(cloudpickle.loads(method), self.worker) func = partial(cloudpickle.loads(method), self.worker)
output = func(*args, **kwargs) output = func(*args, **kwargs)
except Exception as e: except Exception as e:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
self.worker_response_mq.enqueue( self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e)) (WorkerProc.ResponseStatus.FAILURE, e))
logger.exception("WorkerProc hit an exception: %s", exc_info=e) logger.exception("WorkerProc hit an exception: %s", exc_info=e)
......
...@@ -11,7 +11,7 @@ logger = init_logger(__name__) ...@@ -11,7 +11,7 @@ logger = init_logger(__name__)
@dataclass @dataclass
class KVCacheSpecBase: class KVCacheSpec:
""" """
A base class for specifying the KV cache format of one layer. A base class for specifying the KV cache format of one layer.
""" """
...@@ -55,7 +55,7 @@ class KVCacheSpecBase: ...@@ -55,7 +55,7 @@ class KVCacheSpecBase:
@dataclass @dataclass
class FullAttentionSpec(KVCacheSpecBase): class FullAttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
...@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase): ...@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes return cdiv(num_tokens, self.block_size) * self.page_size_bytes
KVCacheSpec = dict[str, KVCacheSpecBase]
@dataclass @dataclass
class KVCacheTensor: class KVCacheTensor:
""" """
...@@ -89,6 +86,18 @@ class KVCacheTensor: ...@@ -89,6 +86,18 @@ class KVCacheTensor:
size: int # The size of KV cache Tensor in bytes size: int # The size of KV cache Tensor in bytes
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """
    # The names of the model layers that belong to this group.
    layer_names: list[str]
    # The KV cache spec of this manager layer, i.e. the single spec that
    # applies to every layer listed in `layer_names`.
    kv_cache_spec: KVCacheSpec
@dataclass @dataclass
class KVCacheConfig: class KVCacheConfig:
""" """
...@@ -99,17 +108,24 @@ class KVCacheConfig: ...@@ -99,17 +108,24 @@ class KVCacheConfig:
"""layer_name -> how to initialize KV cache for that layer""" """layer_name -> how to initialize KV cache for that layer"""
tensors: dict[str, KVCacheTensor] tensors: dict[str, KVCacheTensor]
""" """
A list of kv-cache groups. Each group includes a set of layers with The kv cache groups of the model.
the same kv-cache spec, and the total page_size of layers inside a group The layers in the models are repeated with some patterns, e.g., a model
is same across all groups (as the KVCacheManager only supports allocating with 10 full attention layers and 20 sliding window attention layers can be
pages of the same size). For example: regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
1. A model only uses full attention: one group with all layers in the model. The KVCacheManager allocates different block tables for each of the 3 layers
2. (not implemented yet) A model with the same number of full attention in the pattern, and repeats each of them 10 times to generate the
layers and sliding window attention layers: two groups, one for full block_table for the 30 layers in the model.
attention layers and one for sliding window attention layers. Therefore, we can group the layers in the model into 3 groups, each of which
3. (not implemented yet) A model with 2 full attention layers and 4 sliding contains 10 layers in the model.
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
""" """
groups: list[list[str]] kv_cache_groups: list[KVCacheGroupSpec]
"""the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec
...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.output_processor import RequestState from vllm.v1.engine.output_processor import RequestState
@dataclass @dataclass
......
...@@ -65,6 +65,15 @@ class TopKTopPSampler(nn.Module): ...@@ -65,6 +65,15 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the " "native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.") "best performance, please install FlashInfer.")
self.forward = self.forward_native self.forward = self.forward_native
elif current_platform.is_tpu():
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
logger.warning(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow.")
self.forward = self.forward_native
else:
self.forward = self.forward_tpu
else: else:
self.forward = self.forward_native self.forward = self.forward_native
...@@ -96,6 +105,29 @@ class TopKTopPSampler(nn.Module): ...@@ -96,6 +105,29 @@ class TopKTopPSampler(nn.Module):
return random_sample(probs, generators) return random_sample(probs, generators)
return flashinfer_sample(probs, k, p, generators) return flashinfer_sample(probs, k, p, generators)
def forward_tpu(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """TPU sampling path for top-k / top-p.

    Only the "top-k without top-p" case filters the logits. In every other
    case (top-p only, both, or neither) the logits are sampled unfiltered,
    because the TPU-optimized top-p kernel is still a placeholder.

    Args:
        logits: Logits to sample from. Modified in place when top-k is
            applied.
        generators: Generators forwarded to `random_sample`.
        k: Top-k value, or None to disable top-k.
        p: Top-p value, or None to disable top-p.

    Returns:
        The result of `random_sample` on the float32 softmax of the
        (possibly filtered) logits.
    """
    # If only top-k is specified, use pytorch's builtin topk op. This leads
    # to significant speed up on TPU compared to using apply_top_k_top_p.
    if k is not None and p is None:
        # NOTE(review): `k` is passed to torch.topk unchanged; torch.topk
        # documents `k` as an int, so this presumably expects a scalar
        # value rather than a per-request vector -- confirm with callers.
        # Only the indices are needed to build the mask.
        _, topk_indices = torch.topk(logits, k, dim=-1)
        # Start with everything masked, then unmask the top-k positions.
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask.scatter_(-1, topk_indices, False)
        logits.masked_fill_(mask, float('-inf'))
    else:
        # TODO Placeholder for TPU optimized topp kernel
        # logits = apply_top_k_top_p(logits, k, p)
        pass

    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)
def apply_top_k_top_p( def apply_top_k_top_p(
logits: torch.Tensor, logits: torch.Tensor,
...@@ -112,7 +144,7 @@ def apply_top_k_top_p( ...@@ -112,7 +144,7 @@ def apply_top_k_top_p(
if k is not None: if k is not None:
# Apply top-k. # Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values. # Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask top_k_mask = logits_sort < top_k_mask
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata
@dataclass
class TPUSupportedSamplingMetadata:
    """XLA-friendly sampling metadata used on TPU.

    This class exposes a more xla-friendly interface than SamplingMetadata
    on TPU, in particular all arguments should be traceable and no optionals
    are allowed, to avoid graph recompilation on Nones.

    Attributes without a type annotation below are plain class attributes
    (not dataclass fields); they pin unsupported features to a fixed value.
    """
    # One entry per (padded) sample; see `get_default_sampling_params`.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False

    # Greedy sampling flag for compiling single xla graph.
    # None is only a constructor sentinel; replaced in `__post_init__`.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    # None is only a constructor sentinel; replaced in `__post_init__`.
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        """Fill in tensor defaults that depend on `temperature`."""
        temp = self.temperature
        if self.indices_do_sample is None:
            # One int32 zero index per temperature entry, on the same device.
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            # Scalar boolean False.
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: "SamplingMetadata",
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]

        Args:
            metadata: Source sampling metadata.
            padded_do_sample_indices: Padded sample indices; its length
                fixes the size of every default tensor.
            num_do_sample: Number of leading entries overwritten with the
                values from `metadata`.
            device: Device on which the tensors are created.

        Returns:
            The populated `TPUSupportedSamplingMetadata`.
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax,
        )
        supported_params = cls._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            new_val = getattr(new_metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                # In-place write into the leading slice; the padded tail
                # keeps its default value.
                new_val[:num_do_sample] = old_val
            setattr(new_metadata, p_name, new_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        """Build an instance whose supported params hold default values.

        Args:
            num_samples: Length of every per-sample tensor.
            device: Device on which the tensors are created.
            indices_do_sample: Optional sample indices; when None,
                `__post_init__` supplies zeros.
            do_argmax: Optional greedy flag; when None, `__post_init__`
                supplies a False scalar.

        Returns:
            A new instance with one constant-filled tensor per supported
            parameter.
        """
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        init_kwargs = {
            p_name: torch.full((num_samples, ),
                               default_val,
                               dtype=dtype,
                               device=device)
            for p_name, (default_val,
                         dtype) in cls._get_default_params_values().items()
        }
        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: "SamplingMetadata") -> "SamplingMetadata":
        """Check invariants of the incoming metadata and return it as-is."""
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        """Map each supported param name to its (default value, dtype)."""
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
"""Sampler layer implementing TPU supported operations."""
import torch
import torch.nn as nn
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Sampler restricted to the operations supported on TPU."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row of `logits`.

        Args:
            logits: Logits to sample from; cast to float32 before sampling.
            sampling_metadata: Sampling parameters for the batch.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            with a trailing dimension of size 1 (one generated token per
            request) and whose `logprobs_tensors` is None.
        """
        # Sampling is carried out on float32 logits.
        next_tokens = self.sample(logits.to(torch.float32), sampling_metadata)

        return SamplerOutput(
            # int32 keeps the tensor small; the extra trailing dimension
            # turns the 1D result into one column of tokens.
            sampled_token_ids=next_tokens.to(torch.int32).unsqueeze(-1),
            logprobs_tensors=None,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of `logits` by its temperature, in place."""
        # In-place op: no new tensor is allocated, `logits` itself changes.
        per_row_temp = temp.unsqueeze(dim=1)
        return logits.div_(per_row_temp)

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the flattened argmax over the last dimension."""
        best = logits.argmax(dim=-1)
        return best.view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Compute both greedy and random samples and select per row.

        Rows whose temperature is below `_SAMPLING_EPS` take the greedy
        token, all other rows take the randomly sampled one.
        """
        # Argmax must be taken first: the steps below modify `logits`
        # in place.
        argmax_tokens = self.greedy_sample(logits)

        temperature = sampling_metadata.temperature
        assert temperature is not None

        # Temperature scaling.
        logits = self.apply_temperature(logits, temperature)

        # min_p filtering.
        min_p = sampling_metadata.min_p
        if min_p is not None:
            logits = self.apply_min_p(logits, min_p)

        # top_k and/or top_p, followed by the random draw.
        drawn_tokens = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        return torch.where(temperature < _SAMPLING_EPS, argmax_tokens,
                           drawn_tokens)

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Top-k entries of every row.
        top_values, top_ids = torch.topk(logprobs, num_logprobs, dim=-1)

        # Logprob of the prompt / sampled token itself.
        token_column = token_ids.unsqueeze(-1)
        chosen = logprobs.gather(-1, token_column)

        # Rank = number of vocab entries at least as likely as the token.
        ranks = (logprobs >= chosen).sum(-1)

        # The token's own entry goes first, followed by the top-k entries.
        # Indices are stored as int32 to reduce the tensor size.
        all_ids = torch.cat((token_column, top_ids), dim=1).to(torch.int32)
        all_logprobs = torch.cat((chosen, top_values), dim=1)

        return LogprobsTensors(all_ids, all_logprobs, ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        A token is kept when its probability is at least `min_p` times the
        highest probability in its row; all other logits are set to -inf
        in place.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Highest probability of each row, kept 2D for broadcasting.
        row_peak = torch.amax(probs, dim=-1, keepdim=True)
        # Per-row cut-off.
        threshold = min_p.unsqueeze(1) * row_peak
        keep = probs >= threshold
        # Masked fill rather than indexing (xla friendly).
        logits.masked_fill_(~keep, -float("inf"))
        return logits
...@@ -10,7 +10,8 @@ class NgramProposer: ...@@ -10,7 +10,8 @@ class NgramProposer:
def propose( def propose(
self, self,
context_token_ids: np.ndarray, context_token_ids: np.ndarray,
n: int, min_n: int,
max_n: int,
k: int, k: int,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern """Proposes the next sequence of tokens based on n-gram pattern
...@@ -21,7 +22,8 @@ class NgramProposer: ...@@ -21,7 +22,8 @@ class NgramProposer:
Args: Args:
context_token_ids: Numpy array of token IDs representing the context_token_ids: Numpy array of token IDs representing the
context sequence. context sequence.
n: Length of the n-gram to match. min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens follow the match. If there are less k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return than k tokens follow the match, we will return
the maximum amount of tokens until the end. the maximum amount of tokens until the end.
...@@ -32,14 +34,21 @@ class NgramProposer: ...@@ -32,14 +34,21 @@ class NgramProposer:
None: If no matching n-gram pattern is found. None: If no matching n-gram pattern is found.
Example: Example:
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4: If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4:
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
- The last 2 tokens [2,3] will be matched against the previous - The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4]. 4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that - Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match. we only have three tokens after the match.
""" """
return _find_subarray_kmp(context_token_ids, n, k) # TODO(woosuk): Optimize this.
for n in range(max_n, min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
@jit(nopython=True) @jit(nopython=True)
......
...@@ -9,7 +9,6 @@ from vllm.config import VllmConfig ...@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar) StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np import numpy as np
...@@ -47,6 +46,9 @@ class StructuredOutputManager: ...@@ -47,6 +46,9 @@ class StructuredOutputManager:
if self.backend is None: if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar": if backend_name == "xgrammar":
from vllm.v1.structured_output.backend_xgrammar import (
XgrammarBackend)
self.backend = XgrammarBackend(self.vllm_config) self.backend = XgrammarBackend(self.vllm_config)
else: else:
raise ValueError( raise ValueError(
......
...@@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer_group = init_tokenizer_from_configs( tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
...@@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
def compile_grammar(self, request_type: StructuredOutputOptions, def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar: grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON: if request_type == StructuredOutputOptions.JSON:
ctx = self.compiler.compile_json_schema(grammar_spec, ctx = self.compiler.compile_json_schema(
any_whitespace=False) grammar_spec, any_whitespace=not self.disable_any_whitespace)
elif request_type == StructuredOutputOptions.JSON_OBJECT: elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar() ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR: elif request_type == StructuredOutputOptions.GRAMMAR:
......
...@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin ...@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
from vllm.v1.core.scheduler_output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
else: else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
...@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builder = self.attn_backend.get_builder_cls()( self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self)) weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY self.input_registry = INPUT_REGISTRY
...@@ -150,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -150,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram. assert self.speculative_config.method == "ngram", \
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.drafter = NgramProposer() self.drafter = NgramProposer()
...@@ -159,7 +159,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -159,7 +159,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second. # This usually takes less than 1 second.
self.drafter.propose( self.drafter.propose(
np.zeros(1024, dtype=np.int32), np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens, self.speculative_config.num_speculative_tokens,
) )
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
...@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens], self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True) non_blocking=True)
# Prepare for cascade attention if needed. # Prepare for cascade attention if enabled & beneficial.
common_prefix_len = self._compute_cascade_attn_prefix_len( common_prefix_len = 0
num_scheduled_tokens, if self.cascade_attn_enabled:
scheduler_output.num_common_prefix_blocks, common_prefix_len = self._compute_cascade_attn_prefix_len(
) num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
...@@ -1151,7 +1155,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1151,7 +1155,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose( drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx], self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens, self.speculative_config.num_speculative_tokens,
) )
if drafter_output is None or len(drafter_output) == 0: if drafter_output is None or len(drafter_output) == 0:
...@@ -1506,34 +1511,46 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1506,34 +1511,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer cache size of each layer
""" """
if len(kv_cache_config.groups) > 1: if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError( raise NotImplementedError(
"Hybrid models with more than one KV cache type are not " "Hybrid models with more than one KV cache type are not "
"supported yet.") "supported yet.")
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): for kv_cache_group in kv_cache_config.kv_cache_groups:
tensor_config = kv_cache_config.tensors[layer_name] kv_cache_spec = kv_cache_group.kv_cache_spec
assert tensor_config.size % layer_spec.page_size_bytes == 0 for layer_name in kv_cache_group.layer_names:
num_blocks = tensor_config.size // layer_spec.page_size_bytes tensor_config = kv_cache_config.tensors[layer_name]
if isinstance(layer_spec, FullAttentionSpec): assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, # `num_blocks` is the number of blocks the model runner can use.
layer_spec.head_size) # `kv_cache_config.num_blocks` is the number of blocks that
dtype = layer_spec.dtype # KVCacheManager may allocate.
kv_caches[layer_name] = torch.zeros(kv_cache_shape, # Since different GPUs may have different number of layers and
dtype=dtype, # different memory capacities, `num_blocks` can be different on
device=self.device) # different GPUs, and `kv_cache_config.num_blocks` is set to
else: # the min of all `num_blocks`. Verify it here.
raise NotImplementedError assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache( bind_kv_cache(
kv_caches, kv_caches,
self.vllm_config.compilation_config.static_forward_context, self.vllm_config.compilation_config.static_forward_context,
self.kv_caches) self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """
Generates the KVCacheSpec by parsing the kv cache format from each Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context. Attention module in the static forward context.
...@@ -1545,7 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1545,7 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx = self.vllm_config.compilation_config.static_forward_context forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: KVCacheSpec = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items(): for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE): if isinstance(attn_module, FusedMoE):
continue continue
...@@ -1558,7 +1575,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1558,7 +1575,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=attn_module.dtype, dtype=self.kv_cache_dtype,
use_mla=use_mla) use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
......
...@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase ...@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
class Worker(WorkerBase): class Worker(WorkerBase):
...@@ -185,7 +185,7 @@ class Worker(WorkerBase): ...@@ -185,7 +185,7 @@ class Worker(WorkerBase):
return int(available_kv_cache_memory) return int(available_kv_cache_memory)
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
......
...@@ -11,6 +11,7 @@ import torch.nn as nn ...@@ -11,6 +11,7 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -23,18 +24,21 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality ...@@ -23,18 +24,21 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata) PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,6 +46,8 @@ logger = init_logger(__name__) ...@@ -42,6 +46,8 @@ logger = init_logger(__name__)
# FIXME(woosuk): Find a more reliable way to prevent possible bugs. # FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000 _PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1 INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
class TPUModelRunner: class TPUModelRunner:
...@@ -68,6 +74,10 @@ class TPUModelRunner: ...@@ -68,6 +74,10 @@ class TPUModelRunner:
scheduler_config = self.scheduler_config scheduler_config = self.scheduler_config
parallel_config = self.parallel_config parallel_config = self.parallel_config
self.device = device self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
...@@ -138,8 +148,10 @@ class TPUModelRunner: ...@@ -138,8 +148,10 @@ class TPUModelRunner:
device="cpu") device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.slot_mapping_np = self.slot_mapping_cpu.numpy()
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros( self.block_table_cpu = torch.zeros(
(self.max_num_tokens, self.max_num_blocks_per_req), (self.max_num_tokens, padded_max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype, dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu") device="cpu")
...@@ -267,6 +279,9 @@ class TPUModelRunner: ...@@ -267,6 +279,9 @@ class TPUModelRunner:
req_data.num_computed_tokens) req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids, self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index) req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
...@@ -284,13 +299,17 @@ class TPUModelRunner: ...@@ -284,13 +299,17 @@ class TPUModelRunner:
# Condense the batched states if there are empty indices. # Condense the batched states if there are empty indices.
if removed_req_indices: if removed_req_indices:
self.input_batch.condense(removed_req_indices) self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
assert self.model is not None assert self.model is not None
return self.model return self.model
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
""" """
Generates the KVCacheSpec by parsing the kv cache format from each Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context. Attention module in the static forward context.
...@@ -301,7 +320,7 @@ class TPUModelRunner: ...@@ -301,7 +320,7 @@ class TPUModelRunner:
forward_ctx = self.vllm_config.compilation_config.static_forward_context forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items(): for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window, # TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA. # cross-attention, MLA.
...@@ -447,6 +466,8 @@ class TPUModelRunner: ...@@ -447,6 +466,8 @@ class TPUModelRunner:
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
padded_num_reqs = _get_padded_num_reqs_with_upper_limit( padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs) num_reqs, self.max_num_reqs)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device) logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices return attn_metadata, logits_indices
...@@ -576,7 +597,14 @@ class TPUModelRunner: ...@@ -576,7 +597,14 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids input_ids = self.input_ids
inputs_embeds = None inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
# Run the decoder # Run the decoder
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
...@@ -585,12 +613,13 @@ class TPUModelRunner: ...@@ -585,12 +613,13 @@ class TPUModelRunner:
kv_caches=self.kv_caches, kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
num_reqs = self.input_batch.num_reqs selected_token_ids = self.model.sample_from_hidden(
selected_token_ids = self.model.compute_logits(hidden_states, hidden_states, tpu_sampling_metadata)
logits_indices, None) # Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs] selected_token_ids = selected_token_ids.cpu()[:num_reqs]
# Then, let's update the cache state. # Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None assert req_id is not None
...@@ -607,7 +636,6 @@ class TPUModelRunner: ...@@ -607,7 +636,6 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4) generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all( assert all(
req_id is not None for req_id in req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None" self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
...@@ -620,6 +648,7 @@ class TPUModelRunner: ...@@ -620,6 +648,7 @@ class TPUModelRunner:
max_gen_len = selected_token_ids.shape[-1] max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1: if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist() valid_sampled_token_ids = selected_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens: for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0] token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id self.input_batch.token_ids_cpu[i, seq_len] = token_id
...@@ -647,6 +676,12 @@ class TPUModelRunner: ...@@ -647,6 +676,12 @@ class TPUModelRunner:
logprobs=None, logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
) )
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if self.check_recompilation and not self.enforce_eager:
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
return model_runner_output return model_runner_output
def load_model(self) -> None: def load_model(self) -> None:
...@@ -676,11 +711,8 @@ class TPUModelRunner: ...@@ -676,11 +711,8 @@ class TPUModelRunner:
fullgraph=True, fullgraph=True,
dynamic=False) dynamic=False)
def _dummy_run( @torch.no_grad()
self, def _dummy_run(self, kv_caches, num_tokens: int) -> None:
kv_caches,
num_tokens: int,
) -> None:
if self.is_multimodal_model: if self.is_multimodal_model:
input_ids = None input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size), inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
...@@ -729,32 +761,10 @@ class TPUModelRunner: ...@@ -729,32 +761,10 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0): with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None self.model(input_ids=input_ids,
hidden_states = self.model( positions=position_ids,
input_ids=input_ids, kv_caches=kv_caches,
positions=position_ids, inputs_embeds=inputs_embeds)
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = _get_padded_num_reqs_with_upper_limit(
64, self.max_num_reqs)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while True:
logits_indices = torch.zeros(
num_reqs,
dtype=torch.int32,
device=self.device,
)
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(logits_indices, 0)
self.model.compute_logits(hidden_states, logits_indices, None)
if num_reqs >= self.max_num_reqs:
break
num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs + 1, self.max_num_reqs)
def capture_model(self) -> None: def capture_model(self) -> None:
"""Compile the model.""" """Compile the model."""
...@@ -764,16 +774,62 @@ class TPUModelRunner: ...@@ -764,16 +774,62 @@ class TPUModelRunner:
start = time.perf_counter() start = time.perf_counter()
num_tokens = 16 num_tokens = 16
while True: while True:
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens) logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step() xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.max_num_tokens: if num_tokens >= self.max_num_tokens:
break break
num_tokens *= 2 num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter() end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start) logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while True:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
""" """
Initialize KV cache based on `kv_cache_config`. Initialize KV cache based on `kv_cache_config`.
...@@ -781,31 +837,33 @@ class TPUModelRunner: ...@@ -781,31 +837,33 @@ class TPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer cache size of each layer
""" """
if len(kv_cache_config.groups) > 1: if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError( raise NotImplementedError(
"Hybrid models with more than one KV cache type are not " "Hybrid models with more than one KV cache type are not "
"supported yet.") "supported yet.")
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): for kv_cache_group in kv_cache_config.kv_cache_groups:
tensor_config = kv_cache_config.tensors[layer_name] kv_cache_spec = kv_cache_group.kv_cache_spec
assert tensor_config.size % layer_spec.page_size_bytes == 0 for layer_name in kv_cache_group.layer_names:
num_blocks = tensor_config.size // layer_spec.page_size_bytes tensor_config = kv_cache_config.tensors[layer_name]
if isinstance(layer_spec, FullAttentionSpec): assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, if isinstance(kv_cache_spec, FullAttentionSpec):
layer_spec.head_size) kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
dtype = layer_spec.dtype num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
tpu_k_cache = torch.zeros(kv_cache_shape, dtype = kv_cache_spec.dtype
dtype=dtype,
device=self.device) tpu_k_cache = torch.zeros(kv_cache_shape,
tpu_v_cache = torch.zeros_like(tpu_k_cache) dtype=dtype,
device=self.device)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) tpu_v_cache = torch.zeros_like(tpu_k_cache)
else:
raise NotImplementedError kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
bind_kv_cache( bind_kv_cache(
kv_caches, kv_caches,
...@@ -818,6 +876,13 @@ class ModelWrapperV1(nn.Module): ...@@ -818,6 +876,13 @@ class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
super().__init__() super().__init__()
self.model = model self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward( def forward(
self, self,
...@@ -826,7 +891,7 @@ class ModelWrapperV1(nn.Module): ...@@ -826,7 +891,7 @@ class ModelWrapperV1(nn.Module):
kv_caches: list[tuple[torch.Tensor, torch.Tensor]], kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token. """Executes the forward pass of the model.
Args: Args:
input_ids: The input token IDs of shape [num_tokens]. input_ids: The input token IDs of shape [num_tokens].
...@@ -837,7 +902,6 @@ class ModelWrapperV1(nn.Module): ...@@ -837,7 +902,6 @@ class ModelWrapperV1(nn.Module):
hidden_size]. It is used for multimodal models. hidden_size]. It is used for multimodal models.
""" """
assert self.model is not None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
...@@ -846,17 +910,33 @@ class ModelWrapperV1(nn.Module): ...@@ -846,17 +910,33 @@ class ModelWrapperV1(nn.Module):
return hidden_states return hidden_states
@torch.compile(backend="openxla", fullgraph=True, dynamic=False) def sample_from_hidden(
def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
logits_indices: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata,
sampling_metadata, ) -> torch.Tensor:
) -> Optional[torch.Tensor]: """
hidden_states = hidden_states[logits_indices] Sample with xla-friendly function. This function is to be traced
logits = self.model.compute_logits(hidden_states, sampling_metadata) separately from `forward` for lighter compilation overhead.
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) """
return selected_token_ids # Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph on Sampler.
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
return out_tokens
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits = self.model.compute_logits(hidden_states, None)
return logits
def get_multimodal_embeddings(self, *args, **kwargs): def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs) return self.model.get_multimodal_embeddings(*args, **kwargs)
...@@ -876,5 +956,5 @@ def _get_padded_token_len(x: int) -> int: ...@@ -876,5 +956,5 @@ def _get_padded_token_len(x: int) -> int:
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int: def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = 64 if x <= 64 else 1 << (x - 1).bit_length() res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit) return min(res, upper_limit)
...@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -189,7 +189,7 @@ class TPUWorker: ...@@ -189,7 +189,7 @@ class TPUWorker:
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
......
...@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0): ...@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
self.device: Optional[torch.device] = None self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None self.model_runner: Optional[nn.Module] = None
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation.""" """Get specifications for KV cache implementation."""
raise NotImplementedError raise NotImplementedError
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__)
class ModelInput(NamedTuple):
    """Batched inputs for a single model execution step.

    Instances are built by ``OpenVINOModelRunner._prepare_model_input`` and
    unpacked positionally by ``prepare_input_tensors``, so the field order
    is part of the contract.
    """
    # Token ids of all scheduled sequences, concatenated into one tensor.
    input_tokens: torch.Tensor
    # Position of each entry of ``input_tokens`` within its own sequence.
    input_positions: torch.Tensor
    # Attention metadata for the batch; ``None`` for an empty batch.
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # One entry per sequence: the sequence length used for this step.
    seq_lens: List[int]
    # One entry per sequence: number of tokens fed to the model this step.
    query_lens: List[int]
    # Batched multi-modal keyword arguments (empty mapping when unused).
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Return an input batch that contains no sequences.

        Args:
            device: Device on which the two zero-length tensors are created.

        Returns:
            An instance with empty tensors, empty length lists, no attention
            metadata and no multi-modal kwargs.
        """
        # Build through ``cls`` so the alternate constructor honours the
        # class it is invoked on instead of a hard-coded class name.
        return cls(input_tokens=torch.empty(0, device=device),
                   input_positions=torch.empty(0, device=device),
                   attn_metadata=None,
                   seq_lens=[],
                   query_lens=[],
                   multi_modal_kwargs={})
class OpenVINOModelRunner(ModelRunnerBase):
def __init__(
self,
ov_core: ov.Core,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.ov_core = ov_core
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = self.model_config.get_sliding_window()
self.block_size = self.cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config,
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInput:
"""Prepare the model input based on a given sequence group.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
"""
input_tokens: List[int] = []
input_positions: List[int] = []
seq_lens: List[int] = []
past_lens: List[int] = []
query_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
subsequence_begins: List[int] = []
block_indices: List[int] = []
block_indices_begins: List[int] = []
# initialize beginning of prefix sums
subsequence_begins.append(0)
block_indices_begins.append(0)
if len(seq_group_metadata_list) == 0:
return ModelInput.empty(self.device)
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
is_prompt = seq_group_metadata.is_prompt
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
if is_prompt:
computed_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
computed_len = seq_data.get_len() - 1
seq_len = min(
seq_data.get_len(),
computed_len + seq_group_metadata.token_chunk_size,
)
if is_prompt:
tokens = seq_data.get_token_ids()[computed_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
computed_len = len(computed_block_nums) * self.block_size
tokens = tokens[computed_len:]
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if self.sliding_window is not None:
# chunked prefill doesn't support sliding window.
assert not self.scheduler_config.chunked_prefill_enabled # noqa: E501
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []
else:
# prompt phase w/o prefix_caching, chunked_prefill
pass
block_indices.extend(block_table)
block_indices_begins.append(block_indices_begins[-1] +
len(block_table))
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if self.sliding_window is not None and not is_prompt:
seq_len = min(seq_len, self.sliding_window)
computed_len = seq_len - 1
seq_lens.append(seq_len)
query_len = seq_len - computed_len
query_lens.append(query_len)
input_tokens.extend(tokens)
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
past_lens.append(computed_len)
subsequence_begins.append(subsequence_begins[-1] + query_len)
if is_prompt:
assert len(seq_ids) == 1
else:
assert (
query_len == 1
), "seq_len: {}, computed_len: {}, query_len: {}".format(
seq_len, computed_len, query_len)
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal
# items that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
if self.mm_registry.has_processor(self.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map, )
max_query_len = max(query_lens)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore
past_lens_tensor = torch.tensor(past_lens,
dtype=torch.int32,
device=self.device) # type: ignore
subsequence_begins_tensor = torch.tensor(
subsequence_begins, dtype=torch.int32,
device=self.device) # type: ignore
block_indices_tensor = torch.tensor(block_indices,
dtype=torch.int32,
device=self.device) # type: ignore
block_indices_begins_tensor = torch.tensor(
block_indices_begins, dtype=torch.int32,
device=self.device) # type: ignore
max_context_len = max(seq_lens)
max_context_len_tensor = torch.tensor(
max_context_len, dtype=torch.int32,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_openvino_metadata(
past_lens=past_lens_tensor,
subsequence_begins=subsequence_begins_tensor,
block_indices=block_indices_tensor,
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return ModelInput(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
multi_modal_kwargs=multi_modal_kwargs,
)
def prepare_input_tensors(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
           SamplingMetadata, BatchedTensorInputs]:
    """Build the model inputs and the sampling metadata for one batch.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this step.

    Returns:
        A 5-tuple ``(input_tokens, input_positions, attn_metadata,
        sampling_metadata, multi_modal_kwargs)``.
    """
    # _prepare_model_input yields six fields; the two length lists are only
    # needed here to build the sampling metadata and are not returned.
    (tokens, positions, attn_meta, seq_lens, query_lens,
     mm_kwargs) = self._prepare_model_input(seq_group_metadata_list)

    sampling_meta = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens,
        self.device,
        pin_memory=False,
    )

    return tokens, positions, attn_meta, sampling_meta, mm_kwargs
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
) -> Optional[SamplerOutput]:
    """Run one forward pass and sample the next tokens.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this step.
        kv_caches: KV cache tensors, forwarded to the model as the
            ``kv_caches`` keyword argument.

    Returns:
        Whatever ``self.model.sample`` returns for the computed logits.
    """
    (tokens, positions, attn_meta, sampling_meta,
     mm_kwargs) = self.prepare_input_tensors(seq_group_metadata_list)

    model = self.model

    # Fixed inputs first, then the multi-modal keyword arguments on top
    # (an empty mapping is used when there are none).
    forward_kwargs = {
        "input_ids": tokens,
        "positions": positions,
        "kv_caches": kv_caches,
    }
    forward_kwargs.update(
        MultiModalKwargs.as_kwargs(mm_kwargs or {}, device=self.device))

    with set_forward_context(attn_meta, self.vllm_config, 0):
        hidden_states = model(**forward_kwargs)

    # Hidden states -> logits -> sampled tokens.
    logits = self.model.compute_logits(hidden_states, sampling_meta)
    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_meta,
    )
def prepare_model_input(self, *args, **kwargs):
    """Unsupported entry point: always raises ``NotImplementedError``."""
    raise NotImplementedError
def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
    """Unsupported entry point: always raises ``NotImplementedError``."""
    raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
"""An OpenVINO worker class."""
from typing import Any, Dict, List, Optional, Tuple
import openvino as ov
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.utils import bind_kv_cache
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
# Logger for this module, created through vLLM's init_logger helper.
logger = init_logger(__name__)
class OpenVINOCacheEngine:
    """KV cache owner for the OpenVINO backend.

    On construction the engine allocates one ``(key, value)`` tensor pair per
    layer for the device cache and, when swap blocks are configured, a second
    list of pairs for the swap cache. Block-level swap and copy operations
    are delegated to the attention backend.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
        ov_core: ov.Core,
        ov_device: str,
    ) -> None:
        """Read the cache geometry from the configs and allocate the caches.

        Args:
            cache_config: Supplies block size, cache dtype and block counts.
            model_config: Supplies head size, layer count and KV head count.
            parallel_config: Passed to the model config geometry queries.
            device_config: Must describe an ``"openvino"`` device.
            ov_core: OpenVINO core used to obtain a remote context on
                non-CPU targets.
            ov_device: OpenVINO device name used for that remote context.
        """
        assert device_config.device_type == "openvino"

        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        on_cpu = device_config.device.type == "cpu"
        if on_cpu and cache_config.cache_dtype == ov.Type.u8:
            # For u8 caches the scale and zero point live next to the
            # quantized data. Per token, per head the layout is:
            #   |scale(f32)|zeropoint(f32)|u8 data x head_size|
            # so head_size grows by 8 = sizeof(float) for the scale plus
            # sizeof(float) for the zero point.
            self.head_size += 8

        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.block_size = cache_config.block_size

        # Note: for the OpenVINO backend with a CPU target device,
        # CacheConfig.num_gpu_blocks actually holds the number of CPU
        # blocks, so that the scheduler's KV cache management can be reused.
        self.num_device_blocks = cache_config.num_gpu_blocks
        self.num_swap_blocks = cache_config.num_cpu_blocks

        # Attention backend: provides the cache shape and block operations.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Device cache first, then the swap cache.
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = \
            self._allocate_kv_cache(self.num_device_blocks, ov_core,
                                    ov_device)
        self.swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = \
            self._allocate_swap_cache(self.num_swap_blocks, ov_device)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        ov_core: ov.Core,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocate one (key, value) tensor pair per layer for the device.

        On the OpenVINO CPU platform plain ``ov.Tensor`` objects are created.
        Otherwise tensors are created through the default remote context of
        ``ov_device`` and the key tensor has its last two dimensions swapped
        relative to the value tensor.
        """
        value_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads,
            self.head_size)[1:]

        if current_platform.is_openvino_cpu():
            dtype = self.cache_config.cache_dtype
            # Key and value blocks share the same shape on CPU.
            return [(ov.Tensor(dtype, value_shape),
                     ov.Tensor(dtype, value_shape))
                    for _ in range(self.num_layers)]

        # Key cache uses a layout with the last two dimensions swapped.
        key_shape = (value_shape[0], value_shape[1], value_shape[3],
                     value_shape[2])
        remote_context = ov_core.get_default_context(ov_device)
        dtype = self.cache_config.cache_dtype
        return [(remote_context.create_tensor(dtype, ov.Shape(key_shape),
                                              {}),
                 remote_context.create_tensor(dtype, ov.Shape(value_shape),
                                              {}))
                for _ in range(self.num_layers)]

    def _allocate_swap_cache(
        self,
        num_blocks: int,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocate one (key, value) ``ov.Tensor`` pair per layer for swap.

        Returns an empty list when ``num_blocks`` is zero. A non-zero
        ``num_blocks`` on the OpenVINO CPU platform trips an assertion.
        """
        value_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads,
            self.head_size)[1:]

        if num_blocks == 0:
            return []

        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to have swap cache"

        # Same dimension swap for the key tensor as in the device cache.
        key_shape = (value_shape[0], value_shape[1], value_shape[3],
                     value_shape[2])
        dtype = self.cache_config.cache_dtype
        return [(ov.Tensor(dtype, key_shape), ov.Tensor(dtype, value_shape))
                for _ in range(self.num_layers)]

    def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy the listed blocks from the swap cache into the device cache."""
        for layer_idx in range(self.num_layers):
            swap_pair = self.swap_cache[layer_idx]
            device_pair = self.kv_cache[layer_idx]
            # Pairs are (key, value); handle both tensors of the layer.
            for src_tensor, dst_tensor in zip(swap_pair, device_pair):
                self.attn_backend.swap_blocks(src_tensor, dst_tensor,
                                              src_to_dst)

    def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy the listed blocks from the device cache into the swap cache."""
        for layer_idx in range(self.num_layers):
            swap_pair = self.swap_cache[layer_idx]
            device_pair = self.kv_cache[layer_idx]
            for dst_tensor, src_tensor in zip(swap_pair, device_pair):
                self.attn_backend.swap_blocks(src_tensor, dst_tensor,
                                              src_to_dst)

    def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
        """Copy blocks within the device cache; no-op for an empty list."""
        if len(src_to_dsts) == 0:
            return
        self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: ov.Type,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block over all layers.

        The result is ``cache_dtype.size`` times the number of elements in
        the key and value blocks of every layer.
        """
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        if cache_dtype == ov.Type.u8:
            # u8 caches store scale(f32) and zeropoint(f32) in front of the
            # quantized data of each head, i.e. 8 extra bytes per head.
            head_size += 8

        elements_per_tensor = block_size * num_kv_heads * head_size
        # Key and value blocks have the same number of elements.
        elements_total = num_layers * (2 * elements_per_tensor)
        return cache_dtype.size * elements_total
class OpenVINOWorker(LoRANotSupportedWorkerBase):
    """A worker class that executes the model on OpenVINO backend.

    Each worker is associated with a single OpenVINO device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    OpenVINO backend.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """Create the OpenVINO core and the model runner.

        Args:
            vllm_config: Engine configuration, forwarded to ``WorkerBase``.
            local_rank: Stored on the worker as ``self.local_rank``.
            rank: Rank of this worker; also written to
                ``parallel_config.rank``.
            distributed_init_method: Forwarded later to
                ``init_distributed_environment``.
            is_driver_worker: Whether this worker is the driver; the driver
                must have rank 0.
        """
        WorkerBase.__init__(self, vllm_config)
        self.ov_core = ov.Core()
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        self.model_runner = OpenVINOModelRunner(
            self.ov_core,
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Declared here, assigned by initialize_cache() via
        # _init_cache_engine().
        self.cache_engine: OpenVINOCacheEngine
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]

    def init_device(self) -> None:
        """Initialize the distributed environment and seed the RNG."""
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured
        KV cache space.

        Returns:
            ``(num_device_blocks, num_swap_blocks)``. On the OpenVINO CPU
            platform the swap count is always 0.

        Raises:
            RuntimeError: If the profiling run fails on a non-CPU device.
        """
        # For OpenVINO backend, in case of CPU device, the block number will
        # be calculated based on the openvino_kvcache_space_bytes.
        cache_block_size = self.get_cache_block_size_bytes()
        kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

        if current_platform.is_openvino_cpu():
            num_device_blocks = int(kvcache_space_bytes // cache_block_size)
            num_swap_blocks = 0
        else:
            if kvcache_space_bytes > 0:
                logger.info("KV_CACHE size was explicitly configured via "
                            "VLLM_OPENVINO_KVCACHE_SPACE environment "
                            "variable, ignoring profiling run.")
                kv_cache_size = kvcache_space_bytes
            else:
                try:
                    kv_cache_size = self.profile_run()
                except Exception as err:
                    raise RuntimeError(
                        "The error occurred during profile run. This might be "
                        "due to insufficient GPU memory. Consider decreasing "
                        "`max_model_len` to limit the maximum simultaneously "
                        "processed tokens.") from err

            num_device_blocks = int(kv_cache_size // cache_block_size)
            num_swap_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)

        return num_device_blocks, num_swap_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache.

        Swappable memory is not supported on the OpenVINO CPU platform:
        there ``num_cpu_blocks`` must be 0 and ``num_gpu_blocks`` gives the
        number of non-swappable CPU blocks to allocate.

        Raises:
            ValueError: If ``num_gpu_blocks`` is not positive or cannot hold
                ``max_model_len`` tokens.
        """
        num_device_blocks = num_gpu_blocks
        num_swap_blocks = num_cpu_blocks

        if current_platform.is_openvino_cpu():
            assert (num_swap_blocks == 0
                    ), f"{type(self)} does not support swappable cache for CPU"

        self._validate_num_blocks(num_device_blocks)
        self.cache_config.num_gpu_blocks = num_device_blocks
        self.cache_config.num_cpu_blocks = num_swap_blocks

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_blocks(self, num_blocks: int) -> None:
        """Raise errors if the num_blocks is invalid."""
        if num_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
                "initializing the engine.")

        # The cache must be able to hold at least one full-length sequence.
        max_seq_len = self.cache_config.block_size * num_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
                "when initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create the cache engine and bind its KV cache to the model."""
        ov_device = envs.VLLM_OPENVINO_DEVICE
        self.cache_engine = OpenVINOCacheEngine(
            self.cache_config,
            self.model_config,
            self.parallel_config,
            self.device_config,
            self.ov_core,
            ov_device,
        )
        self.kv_cache = self.cache_engine.kv_cache
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [self.kv_cache])
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.kv_cache is not None

        # Populate the cache to warmup the memory
        if current_platform.is_openvino_cpu():
            for key_cache, value_cache in self.kv_cache:
                key_cache.data[:] = 0
                value_cache.data[:] = 0

    def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Delegate to ``OpenVINOCacheEngine.swap_in``."""
        self.cache_engine.swap_in(src_to_dst)

    def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Delegate to ``OpenVINOCacheEngine.swap_out``."""
        self.cache_engine.swap_out(src_to_dst)

    def cache_copy(
        self,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        """Delegate to ``OpenVINOCacheEngine.copy``."""
        self.cache_engine.copy(blocks_to_copy)  # type: ignore

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Apply the pending cache operations and run one model step.

        The driver worker reads the batch size and block operations from
        ``execute_model_req`` and broadcasts them with
        ``broadcast_tensor_dict(src=0)``; other workers receive them from
        that broadcast.

        Args:
            execute_model_req: The request for this step. Required on the
                driver worker.

        Returns:
            An empty list when there are no sequence groups, otherwise a
            one-element list with the model runner's output.
        """
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            assert execute_model_req is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            blocks_to_copy = execute_model_req.blocks_to_copy
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": blocks_to_copy,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]

        if current_platform.is_openvino_cpu():
            # Check the values every worker actually has. execute_model_req
            # may be None on a non-driver worker, so it must not be
            # dereferenced here.
            assert len(blocks_to_swap_in) == 0
            assert len(blocks_to_swap_out) == 0
        else:
            self.cache_swap_in(blocks_to_swap_in)
            self.cache_swap_out(blocks_to_swap_out)

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.kv_cache)

        # OpenVINO worker only supports single-step execution.
        return [output]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=self.rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
        )

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return OpenVINOCacheEngine.get_cache_block_size(
            self.cache_config.block_size,
            self.cache_config.cache_dtype,
            self.model_config,
            self.parallel_config,
        )

    def profile_run(self) -> int:
        """Estimate the device memory available for the KV cache, in bytes.

        Runs one forward pass with dummy inputs, reads the device memory
        statistics from OpenVINO, and returns
        ``total_device_memory * gpu_memory_utilization - 1.1 * used_memory``
        truncated to an integer.

        Raises:
            RuntimeError: If the (padded) used memory is not smaller than
                the total device memory.
        """
        ov_device = envs.VLLM_OPENVINO_DEVICE

        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to use profile run."

        import openvino.properties.device as device
        import openvino.properties.intel_gpu as intel_gpu

        ov_core = self.ov_core
        cache_config = self.cache_config
        model_config = self.model_config
        parallel_config = self.parallel_config
        device_config = self.device_config
        input_registry = INPUT_REGISTRY
        mm_registry = MULTIMODAL_REGISTRY
        mm_registry.init_mm_limits_per_prompt(model_config)

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        def model_profile_run():
            top_k = model_config.get_vocab_size() - 1
            sampling_params = SamplingParams(top_p=0.99, top_k=top_k)

            max_num_batched_tokens = \
                self.scheduler_config.max_num_batched_tokens
            max_num_seqs = self.scheduler_config.max_num_seqs

            # Temporary single-block cache used only for the dummy pass.
            tmp_cache_config = CacheConfig(cache_config.block_size,
                                           cache_config.gpu_memory_utilization,
                                           cache_config.swap_space_bytes,
                                           "auto")
            tmp_cache_config.num_gpu_blocks = 1
            tmp_cache_config.num_cpu_blocks = 0
            tmp_cache_config.cache_dtype = cache_config.cache_dtype

            profiling_cache_engine = OpenVINOCacheEngine(
                tmp_cache_config, model_config, parallel_config, device_config,
                ov_core, ov_device)

            # Profile memory usage with max_num_sequences sequences and the
            # total number of tokens equal to max_num_batched_tokens.
            block_size = cache_config.block_size
            seqs: List[SequenceGroupMetadata] = []
            for group_id in range(max_num_seqs):
                # Spread the remainder over the first groups so the lengths
                # sum to max_num_batched_tokens.
                seq_len = (max_num_batched_tokens // max_num_seqs +
                           (group_id < max_num_batched_tokens % max_num_seqs))
                seq_num_blocks = (seq_len + block_size - 1) // block_size

                dummy_data = input_registry \
                    .dummy_data_for_profiling(model_config,
                                              seq_len,
                                              mm_registry)

                # Every entry points at block 0, the only allocated block.
                block_tables = [[0] * seq_num_blocks] * max_num_seqs
                seq = SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: dummy_data.seq_data},
                    sampling_params=sampling_params,
                    block_tables=block_tables,
                    lora_request=None,
                    multi_modal_data=dummy_data.multi_modal_data)
                seqs.append(seq)

            self.model_runner.block_size = tmp_cache_config.block_size

            # NOTE(review): _init_cache_engine binds `[self.kv_cache]`, while
            # the cache is passed unwrapped here - confirm this is intended.
            bind_kv_cache(self.compilation_config.static_forward_context,
                          profiling_cache_engine.kv_cache)
            # Run the model with the dummy inputs.
            self.model_runner.execute_model(seqs,
                                            profiling_cache_engine.kv_cache)

            # Explicitly revert bind_kv_cache and delete temporary KV cache
            # manager to free KV cache when real inputs will be passed to OV
            bind_kv_cache(self.compilation_config.static_forward_context, [[
                torch.tensor([])
                for _ in range(len(profiling_cache_engine.kv_cache))
            ]])

            del profiling_cache_engine

        logger.info(
            "Start profiling run with dummy inputs to evaluate "
            "memory usage for %s. It might take a while.", ov_device)

        model_profile_run()

        gpu_device_type = ov_core.get_property(ov_device, device.type)
        memory_statistics = \
            ov_core.get_property(ov_device, intel_gpu.memory_statistics)
        memory_utilization = cache_config.gpu_memory_utilization

        if gpu_device_type == device.Type.INTEGRATED and \
                memory_utilization >= 0.9:
            logger.warning(
                "iGPU is used with high gpu_memory_utilization=%f "
                "value. This may cause low performance due to "
                "occupying the majority of available system "
                "memory. Please consider decreasing "
                "gpu_memory_utilization or explicitly setting "
                "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
                "variable.", memory_utilization)

        # sum up all used device memory
        device_memory_types = ["cl_mem", "usm_device"]
        used_device_mem = \
            sum(memory_statistics.get(key, 0) for key in device_memory_types)

        if gpu_device_type == device.Type.INTEGRATED:
            used_device_mem += memory_statistics.get("usm_host", 0)

        # there could be unaccounted extra memory reserved by kernels, kept
        # in memory pools, etc
        # therefore, add a threshold to account for this
        used_memory_threshold = 1.1
        used_device_mem *= used_memory_threshold

        total_device_memory = \
            ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)

        def format_memory_size(size) -> str:
            """Render a byte count with the largest fitting unit up to GB."""
            units = ["B", "KB", "MB", "GB"]
            unit_index = 0
            # `>=` so that exactly 1024 of a unit rolls over to the next one.
            while size >= 1024 and unit_index < len(units) - 1:
                size /= 1024
                unit_index += 1
            return f"{size:.2f} {units[unit_index]}"

        total_device_memory_str = format_memory_size(total_device_memory)
        used_device_memory_str = format_memory_size(used_device_mem)

        logger.info(
            "Total %s memory: %s. "
            "Amount of memory required to run the model with "
            "max_num_batched_tokens=%d: %s.", ov_device,
            total_device_memory_str,
            self.scheduler_config.max_num_batched_tokens,
            used_device_memory_str)

        if used_device_mem >= total_device_memory:
            # Both fragments that contain placeholders must be f-strings.
            raise RuntimeError(
                f"The required memory size {used_device_memory_str} for model "
                "is higher than the total available device "
                f"memory {total_device_memory_str}. Please consider to "
                "decrease `max_num_batched_tokens` or increase "
                "`gpu_memory_utilization`")

        # The product with the float utilization is a float; truncate so the
        # declared `int` return type holds.
        return int(total_device_memory * memory_utilization - used_device_mem)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment