Commit 53076d70 authored by zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
......@@ -120,7 +120,7 @@ class Processor:
if not params.guided_decoding or not self.decoding_config:
return
supported_backends = ["xgrammar"]
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "
......@@ -173,7 +173,6 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash,
......
......@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes
def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(output)
return output
def get_kv_cache_specs(self) -> list[KVCacheSpec]:
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec")
return output
......@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
def determine_available_memory(self) -> int: # in bytes
def determine_available_memory(self) -> list[int]: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
......@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item()
return [memory_tensor.item()]
......@@ -5,6 +5,7 @@ import pickle
import signal
import sys
import time
import traceback
import weakref
from dataclasses import dataclass
from enum import Enum, auto
......@@ -370,6 +371,9 @@ class WorkerProc:
func = partial(cloudpickle.loads(method), self.worker)
output = func(*args, **kwargs)
except Exception as e:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
......
......@@ -11,7 +11,7 @@ logger = init_logger(__name__)
@dataclass
class KVCacheSpecBase:
class KVCacheSpec:
"""
A base class for specifying the KV cache format of one layer.
"""
......@@ -55,7 +55,7 @@ class KVCacheSpecBase:
@dataclass
class FullAttentionSpec(KVCacheSpecBase):
class FullAttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
......@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
KVCacheSpec = dict[str, KVCacheSpecBase]
@dataclass
class KVCacheTensor:
"""
......@@ -89,6 +86,18 @@ class KVCacheTensor:
size: int # The size of KV cache Tensor in bytes
@dataclass
class KVCacheGroupSpec:
    """
    A set of model layers that all use one shared KV cache block table.

    From the KV cache manager's point of view, the layers of a group are
    handled as if they were a single layer.
    """
    # Names of the model layers that belong to this group.
    layer_names: list[str]
    # KV cache spec describing this (manager-level) layer.
    kv_cache_spec: KVCacheSpec
@dataclass
class KVCacheConfig:
"""
......@@ -99,17 +108,24 @@ class KVCacheConfig:
"""layer_name -> how to initialize KV cache for that layer"""
tensors: dict[str, KVCacheTensor]
"""
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
is same across all groups (as the KVCacheManager only supports allocating
pages of the same size). For example:
1. A model only uses full attention: one group with all layers in the model.
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
The kv cache groups of the model.
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 groups, each of which
contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
"""
groups: list[list[str]]
"""the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec
kv_cache_groups: list[KVCacheGroupSpec]
......@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.output_processor import RequestState
from vllm.v1.engine.output_processor import RequestState
@dataclass
......
......@@ -65,6 +65,15 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
logger.warning(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow.")
self.forward = self.forward_native
else:
self.forward = self.forward_tpu
else:
self.forward = self.forward_native
......@@ -96,6 +105,29 @@ class TopKTopPSampler(nn.Module):
return random_sample(probs, generators)
return flashinfer_sample(probs, k, p, generators)
def forward_tpu(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample token ids on TPU, filtering with top-k when only `k` is set.

    Args:
        logits: Logits to sample from. Modified in place when the top-k
            filter is applied.
        generators: Per-request generators, forwarded to `random_sample`.
        k: Top-k value; only used when `p` is None.
            NOTE(review): annotated as a tensor but handed straight to
            `torch.topk` -- confirm what callers actually pass.
        p: Top-p value. Not applied on this path yet (see TODO below).

    Returns:
        The result of `random_sample` on the float32 softmax of the
        (possibly filtered) logits.
    """
    top_k_only = k is not None and p is None
    if top_k_only:
        # Pytorch's builtin topk op is significantly faster on TPU than
        # going through apply_top_k_top_p.
        _, keep_indices = torch.topk(logits, k, dim=-1)
        # Start with every position marked for removal, then clear the
        # mark on the positions that made it into the top-k.
        drop_mask = torch.ones_like(logits, dtype=torch.bool)
        drop_mask.scatter_(-1, keep_indices, False)
        logits.masked_fill_(drop_mask, float('-inf'))
    # TODO Placeholder for TPU optimized topp kernel. Until it exists, any
    # request that sets `p` is sampled from the unfiltered logits:
    # logits = apply_top_k_top_p(logits, k, p)

    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)
def apply_top_k_top_p(
logits: torch.Tensor,
......@@ -112,7 +144,7 @@ def apply_top_k_top_p(
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata
@dataclass
class TPUSupportedSamplingMetadata:
    """XLA-friendly counterpart of `SamplingMetadata` for sampling on TPU.

    Tensor-valued options are stored as fixed-size tensors so that the
    sampler can be traced once; an option is "disabled" by giving it a value
    that makes the corresponding op an identity. Attributes without a type
    annotation are plain class attributes (not dataclass fields) and mark
    features that are not supported on this path.
    """
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False

    # Greedy sampling flag for compiling single xla graph.
    # Filled in by __post_init__ when left as None.
    do_argmax: Optional[torch.Tensor] = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    # Filled in by __post_init__ when left as None.
    indices_do_sample: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Fill `indices_do_sample` / `do_argmax` when they were not given.

        Both defaults are created on the same device as `temperature`.
        """
        temp = self.temperature
        if self.indices_do_sample is None:
            # One (zero) index per entry of `temperature`.
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            # Scalar boolean flag, False by default.
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: "SamplingMetadata",
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]

        Args:
            metadata: Source sampling metadata.
            padded_do_sample_indices: Padded indices to sample at; its length
                is the size of every per-request tensor that is created.
            num_do_sample: Number of leading entries that receive the values
                copied from `metadata`.
            device: Device on which the new tensors are created.

        Returns:
            The populated metadata object.
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax)

        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in cls._get_default_params_values():
            old_val = getattr(metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                new_val = getattr(new_metadata, p_name)
                new_val[:num_do_sample] = old_val
                setattr(new_metadata, p_name, new_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        """Build metadata in which every supported option is a no-op.

        Args:
            num_samples: Length of each per-request parameter tensor.
            device: Device on which the tensors are created.
            indices_do_sample: Passed through to the constructor; when None
                it is filled in by `__post_init__`.
            do_argmax: Passed through to the constructor; when None it is
                filled in by `__post_init__`.

        Returns:
            A new instance whose supported parameters hold their disabling
            default value.
        """
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        init_kwargs = {
            p_name: torch.full((num_samples, ),
                               default_val,
                               dtype=dtype,
                               device=device)
            for p_name, (default_val,
                         dtype) in cls._get_default_params_values().items()
        }
        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: "SamplingMetadata") -> "SamplingMetadata":
        """Check the input metadata and return it unchanged.

        Raises:
            AssertionError: If `all_greedy` is set while `temperature` is not
                None.
        """
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        """Map each supported parameter name to `(disabling value, dtype)`."""
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
"""Sampler layer implementing TPU supported operations."""
import torch
import torch.nn as nn
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Sampler built only from operations supported on TPU.

    Greedy and random candidates are both computed for every row and the
    final token is picked per row with `torch.where`, so no Python-level
    branching on the sampling mode is needed.
    """

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row of `logits`.

        Args:
            logits: Logits to sample from; converted to float32 first.
            sampling_metadata: Sampling parameters.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            with a trailing dimension of size 1 and whose `logprobs_tensors`
            is None.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).

        # Sampling is done in float32; the ids are then narrowed to int32
        # to reduce the tensor size.
        logits = logits.to(torch.float32)
        token_ids = self.sample(logits, sampling_metadata).to(torch.int32)

        # The sampled tokens are expanded to 2D tensor with shape
        # [num_requests, 1], where each row represents one generated
        # token per request.
        return SamplerOutput(
            sampled_token_ids=token_ids.unsqueeze(-1),
            logprobs_tensors=None,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of `logits` by its temperature, in place."""
        per_row_temp = temp.unsqueeze(dim=1)
        # In-place division avoids creating a new tensor.
        return logits.div_(per_row_temp)

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the flattened argmax over the last dimension."""
        return torch.argmax(logits, dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Pick, per row, the greedy or the randomly sampled token.

        Rows whose temperature is below `_SAMPLING_EPS` get the greedy
        token; all other rows get the randomly sampled one.
        """
        # Must happen first: the steps below modify `logits` in place.
        greedy_tokens = self.greedy_sample(logits)

        temperature = sampling_metadata.temperature
        assert temperature is not None

        # Temperature scaling, then the optional min_p filter.
        logits = self.apply_temperature(logits, temperature)
        min_p = sampling_metadata.min_p
        if min_p is not None:
            logits = self.apply_min_p(logits, min_p)

        # Top-k and/or top-p filtering plus the random draw.
        random_tokens = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        use_greedy = temperature < _SAMPLING_EPS
        return torch.where(use_greedy, greedy_tokens, random_tokens)

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # The top-k entries of every row.
        topk_logprobs, topk_indices = logprobs.topk(num_logprobs, dim=-1)

        # Logprob of the prompt or sampled token itself.
        token_column = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_column)

        # Rank = number of vocabulary entries at least as likely as the token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # The token's own entry goes in front of the top-k entries; the
        # indices are stored as int32 to reduce the tensor size.
        merged_indices = torch.cat((token_column, topk_indices),
                                   dim=1).to(torch.int32)
        merged_logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        return LogprobsTensors(merged_indices, merged_logprobs, token_ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        A token is kept when its probability is at least `min_p` times the
        highest probability in its row; every other logit is set to -inf
        in place.
        """
        probs = logits.softmax(dim=-1)
        # Per-row threshold: min_p scaled by the row's highest probability.
        row_max = probs.amax(dim=-1, keepdim=True)
        threshold = min_p.unsqueeze(1) * row_max
        keep_mask = probs >= threshold
        # Masked fill instead of dynamic indexing (xla friendly).
        logits.masked_fill_(~keep_mask, -float("inf"))
        return logits
......@@ -10,7 +10,8 @@ class NgramProposer:
def propose(
self,
context_token_ids: np.ndarray,
n: int,
min_n: int,
max_n: int,
k: int,
) -> Optional[np.ndarray]:
"""Proposes the next sequence of tokens based on n-gram pattern
......@@ -21,7 +22,8 @@ class NgramProposer:
Args:
context_token_ids: Numpy array of token IDs representing the
context sequence.
n: Length of the n-gram to match.
min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return
the maximum amount of tokens until the end.
......@@ -32,14 +34,21 @@ class NgramProposer:
None: If no matching n-gram pattern is found.
Example:
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4:
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
- The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
return _find_subarray_kmp(context_token_ids, n, k)
# TODO(woosuk): Optimize this.
for n in range(max_n, min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
@jit(nopython=True)
......
......@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
if TYPE_CHECKING:
import numpy as np
......@@ -47,6 +46,9 @@ class StructuredOutputManager:
if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar":
from vllm.v1.structured_output.backend_xgrammar import (
XgrammarBackend)
self.backend = XgrammarBackend(self.vllm_config)
else:
raise ValueError(
......
......@@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
......@@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
if request_type == StructuredOutputOptions.JSON:
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
ctx = self.compiler.compile_json_schema(
grammar_spec, any_whitespace=not self.disable_any_whitespace)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
......
......@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
import xgrammar as xgr
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
......@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
......@@ -150,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
......@@ -159,7 +159,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens,
)
self.rejection_sampler = RejectionSampler()
......@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
# Prepare for cascade attention if needed.
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
......@@ -1151,7 +1155,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.prompt_lookup_max,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0:
......@@ -1506,34 +1511,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
raise NotImplementedError
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different number of layers and
# different memory capacities, `num_blocks` can be different on
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
......@@ -1545,7 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: KVCacheSpec = {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
......@@ -1558,7 +1575,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
......
......@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
class Worker(WorkerBase):
......@@ -185,7 +185,7 @@ class Worker(WorkerBase):
return int(available_kv_cache_memory)
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
......
......@@ -11,6 +11,7 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
......@@ -23,18 +24,21 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
......@@ -42,6 +46,8 @@ logger = init_logger(__name__)
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
class TPUModelRunner:
......@@ -68,6 +74,10 @@ class TPUModelRunner:
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
......@@ -138,8 +148,10 @@ class TPUModelRunner:
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, self.max_num_blocks_per_req),
(self.max_num_tokens, padded_max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")
......@@ -267,6 +279,9 @@ class TPUModelRunner:
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
......@@ -284,13 +299,17 @@ class TPUModelRunner:
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
assert self.model is not None
return self.model
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
......@@ -301,7 +320,7 @@ class TPUModelRunner:
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
......@@ -447,6 +466,8 @@ class TPUModelRunner:
# TODO: Support prompt logprobs.
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
......@@ -576,7 +597,14 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
......@@ -585,12 +613,13 @@ class TPUModelRunner:
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = self.input_batch.num_reqs
selected_token_ids = self.model.compute_logits(hidden_states,
logits_indices, None)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
# Then, let's update the cache state.
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
......@@ -607,7 +636,6 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
......@@ -620,6 +648,7 @@ class TPUModelRunner:
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
......@@ -647,6 +676,12 @@ class TPUModelRunner:
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if self.check_recompilation and not self.enforce_eager:
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
return model_runner_output
def load_model(self) -> None:
......@@ -676,11 +711,8 @@ class TPUModelRunner:
fullgraph=True,
dynamic=False)
def _dummy_run(
self,
kv_caches,
num_tokens: int,
) -> None:
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
......@@ -729,32 +761,10 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = _get_padded_num_reqs_with_upper_limit(
64, self.max_num_reqs)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while True:
logits_indices = torch.zeros(
num_reqs,
dtype=torch.int32,
device=self.device,
)
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(logits_indices, 0)
self.model.compute_logits(hidden_states, logits_indices, None)
if num_reqs >= self.max_num_reqs:
break
num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs + 1, self.max_num_reqs)
self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
def capture_model(self) -> None:
"""Compile the model."""
......@@ -764,16 +774,62 @@ class TPUModelRunner:
start = time.perf_counter()
num_tokens = 16
while True:
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while True:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
# Record the number of cached XLA graphs after warm-up; it is used at
# runtime to check that no additional graph compilation takes place.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
......@@ -781,31 +837,33 @@ class TPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
......@@ -818,6 +876,13 @@ class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: TPUSupportedSamplingMetadata,
) -> SamplerOutput:
    """Run the wrapper's sampler on ``logits``.

    Args:
        logits: Logits passed to ``self.sampler`` unchanged.
        sampling_metadata: Sampling configuration forwarded unchanged.

    Returns:
        Whatever ``self.sampler`` returns for these inputs.
    """
    return self.sampler(logits, sampling_metadata)
def forward(
self,
......@@ -826,7 +891,7 @@ class ModelWrapperV1(nn.Module):
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
......@@ -837,7 +902,6 @@ class ModelWrapperV1(nn.Module):
hidden_size]. It is used for multimodal models.
"""
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
......@@ -846,17 +910,33 @@ class ModelWrapperV1(nn.Module):
return hidden_states
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits(
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
logits_indices: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, sampling_metadata)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
return selected_token_ids
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph on Sampler.
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
return out_tokens
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    """Compute logits for ``hidden_states`` with the wrapped model.

    Args:
        hidden_states: Hidden states handed to the wrapped model.

    Returns:
        The result of ``self.model.compute_logits``.
    """
    # The wrapped model accepts a SamplingMetadata argument (used for
    # pruning output in LogitsProcessor); that is disabled here by
    # passing None.
    return self.model.compute_logits(hidden_states, None)
def get_multimodal_embeddings(self, *args, **kwargs):
    """Forward the call, with all arguments, to the wrapped model."""
    wrapped = self.model
    return wrapped.get_multimodal_embeddings(*args, **kwargs)
......@@ -876,5 +956,5 @@ def _get_padded_token_len(x: int) -> int:
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    """Pad a request count to its bucket size, capped at ``upper_limit``.

    Counts up to ``MIN_NUM_SEQS`` map to ``MIN_NUM_SEQS``; larger counts are
    rounded up to the next power of two.

    Args:
        x: Number of requests to pad.
        upper_limit: Largest value that may be returned.

    Returns:
        ``min(bucket, upper_limit)``.
    """
    # A previous assignment using a hard-coded floor of 64 was immediately
    # overwritten by this one and has been dropped as a dead store.
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
......@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
......@@ -189,7 +189,7 @@ class TPUWorker:
def get_model(self) -> nn.Module:
    """Return the model held by this worker's model runner."""
    runner = self.model_runner
    return runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
......
......@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase
logger = init_logger(__name__)
class ModelInput(NamedTuple):
    """Inputs assembled for one model execution step."""

    # Token ids fed to the model.
    input_tokens: torch.Tensor
    # Position of each entry of ``input_tokens`` within its sequence.
    input_positions: torch.Tensor
    # Attention metadata; None for the empty input.
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # Per-sequence lengths.
    seq_lens: List[int]
    # Per-sequence number of query tokens in this step.
    query_lens: List[int]
    # Batched multi-modal keyword arguments.
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Build a ``ModelInput`` that carries no tokens.

        Args:
            device: Device on which the zero-length tensors are created.

        Returns:
            A ``ModelInput`` with empty tensors and lists, no attention
            metadata and an empty multi-modal mapping.
        """
        no_tokens = torch.empty(0, device=device)
        no_positions = torch.empty(0, device=device)
        return ModelInput(
            input_tokens=no_tokens,
            input_positions=no_positions,
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_kwargs={},
        )
class OpenVINOModelRunner(ModelRunnerBase):
def __init__(
    self,
    ov_core: ov.Core,
    vllm_config: VllmConfig,
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    *args,
    **kwargs,
):
    """Set up the runner; the model itself is loaded later by ``load_model``.

    Args:
        ov_core: OpenVINO core object, later handed to the model loader.
        vllm_config: Engine configuration passed to ``ModelRunnerBase``.
        kv_cache_dtype: KV cache dtype used to select the attention backend
            and passed to the model loader.
        is_driver_worker: Stored on the instance as ``is_driver_worker``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """
    self.ov_core = ov_core
    # The base initializer runs first; the config attributes read below
    # (device_config, model_config, cache_config) are expected to come
    # from it.
    ModelRunnerBase.__init__(self, vllm_config=vllm_config)

    self.is_driver_worker = is_driver_worker
    self.device = self.device_config.device
    self.kv_cache_dtype = kv_cache_dtype

    model_config = self.model_config
    self.sliding_window = model_config.get_sliding_window()
    self.block_size = self.cache_config.block_size

    self.attn_backend = get_attn_backend(
        model_config.get_head_size(),
        model_config.dtype,
        self.kv_cache_dtype,
        self.block_size,
        model_config.is_attention_free,
    )

    # Multi-modal data support
    registry = MULTIMODAL_REGISTRY
    self.mm_registry = registry
    self.multi_modal_input_mapper = registry.create_input_mapper(
        model_config)

    # Lazy initialization.
    self.model: nn.Module  # Set by load_model()
def load_model(self) -> None:
    """Load the model via ``get_model`` and store it on ``self.model``."""
    loader_kwargs = dict(
        vllm_config=self.vllm_config,
        kv_cache_dtype=self.kv_cache_dtype,
        ov_core=self.ov_core,
    )
    self.model = get_model(**loader_kwargs)
def get_model(self) -> nn.Module:
    """Return the model stored on this runner."""
    loaded = self.model
    return loaded
def _prepare_model_input(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInput:
    """Prepare the model input based on a given sequence group.

    The API assumes seq_group_metadata_list is sorted by prefill -> decode.

    The result tensors and data structure also batches input in prefill
    -> decode order. For example,

    - input_tokens[:num_prefill_tokens] contains prefill tokens.
    - input_tokens[num_prefill_tokens:] contains decode tokens.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this step.

    Returns:
        A ``ModelInput`` holding token and position tensors, the OpenVINO
        attention metadata, per-sequence ``seq_lens`` / ``query_lens`` and
        the batched multi-modal kwargs. An empty input list yields
        ``ModelInput.empty(self.device)``.

    Raises:
        RuntimeError: If chunked prefill is enabled and a sequence group
            reports already-computed (prefix-cached) blocks.
    """
    input_tokens: List[int] = []
    input_positions: List[int] = []

    seq_lens: List[int] = []
    past_lens: List[int] = []
    query_lens: List[int] = []
    multi_modal_kwargs_list: List[MultiModalKwargs] = []
    multi_modal_placeholder_maps: Dict[
        str,
        MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

    subsequence_begins: List[int] = []
    block_indices: List[int] = []
    block_indices_begins: List[int] = []

    # initialize beginning of prefix sums
    subsequence_begins.append(0)
    block_indices_begins.append(0)

    if len(seq_group_metadata_list) == 0:
        return ModelInput.empty(self.device)

    # Evaluate the flag once, treating a missing scheduler config as
    # "chunked prefill disabled" everywhere it is consulted below.
    chunked_prefill_enabled = (
        self.scheduler_config is not None
        and self.scheduler_config.chunked_prefill_enabled)

    for seq_group_metadata in seq_group_metadata_list:
        seq_ids = list(seq_group_metadata.seq_data.keys())
        is_prompt = seq_group_metadata.is_prompt

        for seq_id in seq_ids:
            computed_block_nums = seq_group_metadata.computed_block_nums
            if (chunked_prefill_enabled
                    and not (computed_block_nums is None
                             or computed_block_nums == [])):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            seq_data = seq_group_metadata.seq_data[seq_id]
            if is_prompt:
                computed_len = seq_data.get_num_computed_tokens()
            else:
                # get_num_computed_tokens is incorrect for spec decoding.
                # So, we should have a special logic here.
                # TODO(sang): Fix it.
                computed_len = seq_data.get_len() - 1

            seq_len = min(
                seq_data.get_len(),
                computed_len + seq_group_metadata.token_chunk_size,
            )
            if is_prompt:
                tokens = seq_data.get_token_ids()[computed_len:seq_len]
            else:
                # Optimization. get_token_ids requires the entire copy of
                # tokens.
                tokens = [seq_data.get_last_token_id()]

            # Prefix cache was hit.
            # Prefix is not supported with sliding_window
            prefix_cache_hit = (computed_block_nums is not None
                                and len(computed_block_nums) > 0
                                and self.sliding_window is None
                                and is_prompt)

            # block_tables may be None (per the original note, only when
            # memory profiling runs); fall back to an empty block table
            # instead of indexing None.
            block_tables = seq_group_metadata.block_tables
            block_table = ([]
                           if block_tables is None else block_tables[seq_id])
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            if prefix_cache_hit:
                assert computed_block_nums is not None
                computed_len = len(computed_block_nums) * self.block_size
                tokens = tokens[computed_len:]
            elif chunked_prefill_enabled or not is_prompt:
                # chunked prefill or decode
                if (block_tables is not None
                        and self.sliding_window is not None):
                    # chunked prefill doesn't support sliding window.
                    assert not chunked_prefill_enabled
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
            # Otherwise: prompt phase w/o prefix_caching, chunked_prefill;
            # the full block table is used as is.

            block_indices.extend(block_table)
            block_indices_begins.append(block_indices_begins[-1] +
                                        len(block_table))

            # TODO(sang): This is a hack to make sliding window work with
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle sliding window attn.
            if self.sliding_window is not None and not is_prompt:
                seq_len = min(seq_len, self.sliding_window)
                computed_len = seq_len - 1

            seq_lens.append(seq_len)

            query_len = seq_len - computed_len
            query_lens.append(query_len)

            input_tokens.extend(tokens)
            positions_range = range(computed_len, seq_len)
            input_positions.extend(list(positions_range))

            past_lens.append(computed_len)
            subsequence_begins.append(subsequence_begins[-1] + query_len)

            if is_prompt:
                assert len(seq_ids) == 1
            else:
                assert (
                    query_len == 1
                ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                    seq_len, computed_len, query_len)

            if seq_group_metadata.multi_modal_data:
                # NOTE: mm_data only includes the subset of multi-modal
                # items that intersect with the current prefill positions.
                mm_data, placeholder_maps = MultiModalPlaceholderMap \
                    .from_seq_group(seq_group_metadata, positions_range)

                if self.mm_registry.has_processor(self.model_config):
                    mm_kwargs = mm_data
                else:
                    mm_kwargs = self.multi_modal_input_mapper(
                        mm_data,
                        seq_group_metadata.mm_processor_kwargs,
                    )

                multi_modal_kwargs_list.append(mm_kwargs)

                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
                        placeholder_map, )

    max_query_len = max(query_lens)
    assert max_query_len > 0, "query_lens: {}".format(query_lens)

    input_tokens = torch.tensor(input_tokens,
                                dtype=torch.long,
                                device=self.device)  # type: ignore
    input_positions = torch.tensor(input_positions,
                                   dtype=torch.long,
                                   device=self.device)  # type: ignore

    past_lens_tensor = torch.tensor(past_lens,
                                    dtype=torch.int32,
                                    device=self.device)  # type: ignore
    subsequence_begins_tensor = torch.tensor(
        subsequence_begins, dtype=torch.int32,
        device=self.device)  # type: ignore
    block_indices_tensor = torch.tensor(block_indices,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
    block_indices_begins_tensor = torch.tensor(
        block_indices_begins, dtype=torch.int32,
        device=self.device)  # type: ignore

    max_context_len = max(seq_lens)
    max_context_len_tensor = torch.tensor(
        max_context_len, dtype=torch.int32,
        device=self.device)  # type: ignore

    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        multi_modal_placeholder_maps.items()
    }

    attn_metadata = self.attn_backend.make_openvino_metadata(
        past_lens=past_lens_tensor,
        subsequence_begins=subsequence_begins_tensor,
        block_indices=block_indices_tensor,
        block_indices_begins=block_indices_begins_tensor,
        max_context_len=max_context_len_tensor,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=False,
    )

    multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

    return ModelInput(
        input_tokens,
        input_positions,
        attn_metadata,
        seq_lens,
        query_lens,
        multi_modal_kwargs=multi_modal_kwargs,
    )
def prepare_input_tensors(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
           SamplingMetadata, BatchedTensorInputs]:
    """Build model inputs and sampling metadata for the given groups.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this step.

    Returns:
        A 5-tuple ``(input_tokens, input_positions, attn_metadata,
        sampling_metadata, multi_modal_kwargs)``.
    """
    # Prepare input tensors.
    (tokens, positions, attn_metadata, seq_lens, query_lens,
     mm_kwargs) = self._prepare_model_input(seq_group_metadata_list)

    # The sequence and query lengths are only needed to derive the
    # sampling metadata; they are not returned.
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens,
        self.device,
        pin_memory=False,
    )

    return (tokens, positions, attn_metadata, sampling_metadata, mm_kwargs)
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
) -> Optional[SamplerOutput]:
    """Run one forward pass and sample the next tokens.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this step.
        kv_caches: KV cache tensors passed to the model as ``kv_caches``.

    Returns:
        The output of ``self.model.sample``.
    """
    (tokens, positions, attn_metadata, sampling_metadata,
     mm_kwargs) = self.prepare_input_tensors(seq_group_metadata_list)

    model_executable = self.model
    mm_inputs = MultiModalKwargs.as_kwargs(mm_kwargs or {},
                                           device=self.device)
    # Multi-modal entries are unpacked last, so they win on key clashes.
    forward_kwargs = {
        "input_ids": tokens,
        "positions": positions,
        "kv_caches": kv_caches,
        **mm_inputs,
    }

    with set_forward_context(attn_metadata, self.vllm_config, 0):
        hidden_states = model_executable(**forward_kwargs)

    # Compute the logits.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Sample the next token.
    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
def prepare_model_input(self, *args, **kwargs):
    """Not implemented by this runner; always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
    """Not implemented by this runner; always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
# SPDX-License-Identifier: Apache-2.0
"""An OpenVINO worker class."""
from typing import Any, Dict, List, Optional, Tuple
import openvino as ov
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.utils import bind_kv_cache
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
class OpenVINOCacheEngine:
"""Manages the KV cache for OpenVINO backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(
    self,
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    device_config: DeviceConfig,
    ov_core: ov.Core,
    ov_device: str,
) -> None:
    """Derive cache geometry from the configs and allocate both caches.

    Args:
        cache_config: Supplies block size, cache dtype and block counts.
        model_config: Supplies head size, layer count and KV head count.
        parallel_config: Passed to the per-rank model-config queries.
        device_config: Must have ``device_type == "openvino"``.
        ov_core: OpenVINO core used when allocating the KV cache.
        ov_device: OpenVINO device name used for allocation.
    """
    assert device_config.device_type == "openvino"
    self.cache_config = cache_config
    self.model_config = model_config
    self.parallel_config = parallel_config

    head_size = model_config.get_head_size()
    quantized_on_cpu = (device_config.device.type == "cpu"
                        and cache_config.cache_dtype == ov.Type.u8)
    if quantized_on_cpu:
        # Scale, zero point and quantized data will be stored together.
        # The layout for per token per head:
        # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
        # so, we have to extend head_size by 8, which is sizeof(float)
        # for scale and sizeof(float) for zeropoint
        head_size += 8
    self.head_size = head_size

    self.num_layers = model_config.get_num_layers(parallel_config)
    self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    self.block_size = cache_config.block_size

    # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
    # for OpenVINO backend with a CPU target device, because we want
    # to reuse KV cache management in the scheduler.
    self.num_device_blocks = cache_config.num_gpu_blocks
    self.num_swap_blocks = cache_config.num_cpu_blocks

    # Get attention backend.
    self.attn_backend = get_attn_backend(
        self.head_size,
        self.model_config.dtype,
        self.cache_config.cache_dtype,
        self.block_size,
        self.model_config.is_attention_free,
    )

    # Initialize the cache.
    self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = (
        self._allocate_kv_cache(self.num_device_blocks, ov_core,
                                ov_device))

    # Initialize the swap.
    self.swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = (
        self._allocate_swap_cache(self.num_swap_blocks, ov_device))
def _allocate_kv_cache(
    self,
    num_blocks: int,
    ov_core: ov.Core,
    ov_device: str,
) -> List[Tuple[ov.Tensor, ov.Tensor]]:
    """Allocates KV cache.

    Args:
        num_blocks: Number of cache blocks per layer.
        ov_core: OpenVINO core; its default context for ``ov_device`` is
            used when the platform is not OpenVINO-CPU.
        ov_device: OpenVINO device name.

    Returns:
        One ``(key_blocks, value_blocks)`` pair per layer.
    """
    # Drop the leading dimension of the backend's cache shape.
    block_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
    layer_indices = range(self.num_layers)

    if current_platform.is_openvino_cpu():
        # Keys and values share the same shape here.
        return [(ov.Tensor(self.cache_config.cache_dtype, block_shape),
                 ov.Tensor(self.cache_config.cache_dtype, block_shape))
                for _ in layer_indices]

    # Update key_cache shape: the last two dimensions are swapped.
    key_shape = (block_shape[0], block_shape[1], block_shape[3],
                 block_shape[2])
    remote_context = ov_core.get_default_context(ov_device)

    kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
    for _ in layer_indices:
        key_blocks = remote_context.create_tensor(
            self.cache_config.cache_dtype, ov.Shape(key_shape), {})
        value_blocks = remote_context.create_tensor(
            self.cache_config.cache_dtype, ov.Shape(block_shape), {})
        kv_cache.append((key_blocks, value_blocks))
    return kv_cache
def _allocate_swap_cache(
    self,
    num_blocks: int,
    ov_device: str,
) -> List[Tuple[ov.Tensor, ov.Tensor]]:
    """Allocates swap cache.

    Args:
        num_blocks: Number of swap blocks per layer; 0 yields an empty list.
        ov_device: OpenVINO device name (not used by this method).

    Returns:
        One ``(key_blocks, value_blocks)`` pair per layer, or ``[]`` when
        ``num_blocks`` is 0.
    """
    value_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]

    if num_blocks == 0:
        return []

    assert not current_platform.is_openvino_cpu(), \
        "CPU device isn't supposed to have swap cache"

    # Update key_cache shape: the last two dimensions are swapped.
    key_shape = (value_shape[0], value_shape[1], value_shape[3],
                 value_shape[2])

    return [(ov.Tensor(self.cache_config.cache_dtype, key_shape),
             ov.Tensor(self.cache_config.cache_dtype, value_shape))
            for _ in range(self.num_layers)]
def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Copy blocks from the swap cache into the KV cache.

    Args:
        src_to_dst: ``(source, destination)`` block pairs, passed to the
            attention backend's ``swap_blocks`` for every layer.
    """
    # Nothing to move: return before indexing the swap cache, which is
    # empty when no swap blocks were allocated. This matches the guard
    # already used by ``copy``.
    if len(src_to_dst) == 0:
        return
    for i in range(self.num_layers):
        # Each layer holds a (key, value) pair in both caches.
        for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                          self.kv_cache[i]):
            self.attn_backend.swap_blocks(swap_tensor, kv_tensor,
                                          src_to_dst)
def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Copy blocks from the KV cache into the swap cache.

    Args:
        src_to_dst: ``(source, destination)`` block pairs, passed to the
            attention backend's ``swap_blocks`` for every layer.
    """
    # Nothing to move: return before indexing the swap cache, which is
    # empty when no swap blocks were allocated. This matches the guard
    # already used by ``copy``.
    if len(src_to_dst) == 0:
        return
    for i in range(self.num_layers):
        # Each layer holds a (key, value) pair in both caches.
        for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                          self.kv_cache[i]):
            self.attn_backend.swap_blocks(kv_tensor, swap_tensor,
                                          src_to_dst)
def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
    """Copy blocks within the KV cache via the attention backend.

    Args:
        src_to_dsts: ``(source, destination)`` block pairs. An empty
            sequence results in no backend call.
    """
    if len(src_to_dsts) == 0:
        return
    self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)
@staticmethod
def get_cache_block_size(
    block_size: int,
    cache_dtype: ov.Type,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    """Return ``cache_dtype.size`` times the element count of one block.

    The element count covers the key and value tensors of every layer.

    Args:
        block_size: Multiplier applied to ``num_kv_heads * head_size``.
        cache_dtype: Cache element type; ``ov.Type.u8`` adds 8 to the
            head size.
        model_config: Supplies head size, KV head count and layer count.
        parallel_config: Passed to the per-rank model-config queries.

    Returns:
        The computed block size.
    """
    head_size = model_config.get_head_size()
    num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    num_layers = model_config.get_num_layers(parallel_config)

    if cache_dtype == ov.Type.u8:
        # Scale, zero point and quantized data will be stored together.
        # The layout for per token per head:
        # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
        # so, we have to extend head_size by 8, which is sizeof(float)
        # for scale and sizeof(float) for zeropoint
        head_size += 8

    # Key and value blocks have the same number of elements.
    elements_per_tensor = block_size * num_kv_heads * head_size
    elements_per_block = num_layers * (elements_per_tensor +
                                       elements_per_tensor)
    return cache_dtype.size * elements_per_block
class OpenVINOWorker(LoRANotSupportedWorkerBase):
"""A worker class that executes the model on OpenVINO backend.
Each worker is associated with a single OpenVINO device. The worker is
responsible for maintaining the KV cache and executing the model on the
OpenVINO backend.
"""
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool = False,
) -> None:
    """Create the worker and its model runner.

    Args:
        vllm_config: Engine configuration passed to ``WorkerBase``.
        local_rank: Stored as ``self.local_rank``.
        rank: Stored as ``self.rank`` and written to ``parallel_config.rank``.
        distributed_init_method: Stored for later distributed setup.
        is_driver_worker: Whether this worker is the driver; a driver must
            have rank 0.
    """
    WorkerBase.__init__(self, vllm_config)
    self.ov_core = ov.Core()

    self.parallel_config.rank = rank
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.is_driver_worker = is_driver_worker
    if self.is_driver_worker:
        assert self.rank == 0, "The driver worker must have rank 0."

    if self.model_config.trust_remote_code:
        # note: lazy import to avoid importing torch before initializing
        from vllm.utils import init_cached_hf_modules

        init_cached_hf_modules()

    runner = OpenVINOModelRunner(
        self.ov_core,
        vllm_config=self.vllm_config,
        kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
        is_driver_worker=is_driver_worker,
    )
    self.model_runner = runner

    # Uninitialized cache engine. Will be initialized by
    # initialize_cache.
    self.cache_engine: OpenVINOCacheEngine
    self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]
def init_device(self) -> None:
    """Initialize the distributed environment, then seed the RNG."""
    self.init_distributed_environment()

    # Set random seed.
    seed = self.model_config.seed
    set_random_seed(seed)
def load_model(self):
    """Ask the model runner to load the model."""
    runner = self.model_runner
    runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of blocks available for the KV cache.

    This determines how many KV blocks can fit into the configured
    KV cache space.

    Returns:
        ``(num_device_blocks, num_swap_blocks)``. On an OpenVINO CPU
        platform the swap count is always 0.

    Raises:
        RuntimeError: If the profiling run fails.
    """
    # For OpenVINO backend, in case of CPU device, the block number will be
    # calculated based on the openvino_kvcache_space_bytes.
    block_bytes = self.get_cache_block_size_bytes()
    configured_bytes = self.cache_config.openvino_kvcache_space_bytes

    if current_platform.is_openvino_cpu():
        return int(configured_bytes // block_bytes), 0

    if configured_bytes > 0:
        logger.info("KV_CACHE size was explicitly configured via "
                    "VLLM_OPENVINO_KVCACHE_SPACE environment "
                    "variable, ignoring profiling run.")
        kv_cache_size = configured_bytes
    else:
        # No explicit size given: measure it with a profiling run.
        try:
            kv_cache_size = self.profile_run()
        except Exception as err:
            raise RuntimeError(
                "The error occurred during profile run. This might be "
                "due to insufficient GPU memory. Consider decreasing "
                "`max_model_len` to limit the maximum simultaneously "
                "processed tokens.") from err

    device_blocks = int(kv_cache_size // block_bytes)
    swap_blocks = int(self.cache_config.swap_space_bytes // block_bytes)
    return device_blocks, swap_blocks
def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Initialize the KV cache. Swappable CPU memory is only
    supported on GPU.

    For CPU, we use the num_gpu_blocks to
    determine how many non-swappable CPU blocks to allocate.

    Args:
        num_gpu_blocks: Number of device blocks to allocate.
        num_cpu_blocks: Number of swap blocks; must be 0 on CPU.
    """
    if current_platform.is_openvino_cpu():
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache for CPU"

    self._validate_num_blocks(num_gpu_blocks)

    cache_config = self.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = num_cpu_blocks

    # Initialize the cache.
    self._init_cache_engine()
def _validate_num_blocks(self, num_blocks: int) -> None:
    """Raise errors if the num_blocks is invalid.

    Args:
        num_blocks: Number of KV cache blocks to validate.

    Raises:
        ValueError: If ``num_blocks`` is not positive, or if the blocks
            hold fewer tokens than the model's maximum length.
    """
    if num_blocks <= 0:
        raise ValueError(
            "No available memory for the cache blocks. "
            "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
            "initializing the engine.")

    # Total number of tokens the cache can hold.
    max_seq_len = self.cache_config.block_size * num_blocks
    if self.model_config.max_model_len <= max_seq_len:
        return

    raise ValueError(
        f"The model's max seq len ({self.model_config.max_model_len}) "
        "is larger than the maximum number of tokens that can be "
        f"stored in KV cache ({max_seq_len}). Try increasing "
        "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
        "when initializing the engine.")
def _init_cache_engine(self) -> None:
    """Create the cache engine and wire its KV cache into the worker."""
    ov_device = envs.VLLM_OPENVINO_DEVICE
    engine = OpenVINOCacheEngine(
        self.cache_config,
        self.model_config,
        self.parallel_config,
        self.device_config,
        self.ov_core,
        ov_device,
    )
    self.cache_engine = engine
    self.kv_cache = engine.kv_cache
    bind_kv_cache(self.compilation_config.static_forward_context,
                  [self.kv_cache])
    self.model_runner.block_size = engine.block_size

    assert self.kv_cache is not None

    if not current_platform.is_openvino_cpu():
        return

    # Populate the cache to warmup the memory
    for key_cache, value_cache in self.kv_cache:
        key_cache.data[:] = 0
        value_cache.data[:] = 0
def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Delegate a swap-in of the given block pairs to the cache engine."""
    engine = self.cache_engine
    engine.swap_in(src_to_dst)
def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Delegate a swap-out of the given block pairs to the cache engine."""
    engine = self.cache_engine
    engine.swap_out(src_to_dst)
def cache_copy(
    self,
    blocks_to_copy: List[Tuple[int, int]],
) -> None:
    """Delegate a copy of the given block pairs to the cache engine."""
    engine = self.cache_engine
    engine.copy(blocks_to_copy)  # type: ignore
def get_model(self) -> nn.Module:
    """Return the model held by this worker's model runner."""
    runner = self.model_runner
    return runner.get_model()
@torch.inference_mode()
def execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
    """Apply pending cache operations and run one model step.

    The driver worker reads the step's cache operations from
    ``execute_model_req`` and broadcasts them; other workers receive them
    from the broadcast.

    Args:
        execute_model_req: Request for this step. Must not be None on the
            driver worker.

    Returns:
        A one-element list with the model runner's output, or an empty
        list when the step contains no sequence groups.
    """
    if execute_model_req is None:
        seq_group_metadata_list = None
    else:
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

    if self.is_driver_worker:
        assert seq_group_metadata_list is not None
        num_seq_groups: int = len(seq_group_metadata_list)
        assert execute_model_req is not None
        blocks_to_copy = execute_model_req.blocks_to_copy
        blocks_to_swap_in = execute_model_req.blocks_to_swap_in
        blocks_to_swap_out = execute_model_req.blocks_to_swap_out
        data: Dict[str, Any] = {
            "num_seq_groups": num_seq_groups,
            "blocks_to_copy": blocks_to_copy,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
        }
        broadcast_tensor_dict(data, src=0)
    else:
        data = broadcast_tensor_dict(src=0)
        num_seq_groups = data["num_seq_groups"]
        blocks_to_copy = data["blocks_to_copy"]
        blocks_to_swap_in = data["blocks_to_swap_in"]
        blocks_to_swap_out = data["blocks_to_swap_out"]

    if current_platform.is_openvino_cpu():
        # Check the values obtained above rather than the request object,
        # which may be None on non-driver workers.
        assert len(blocks_to_swap_in) == 0
        assert len(blocks_to_swap_out) == 0
    else:
        self.cache_swap_in(blocks_to_swap_in)
        self.cache_swap_out(blocks_to_swap_out)

    self.cache_copy(blocks_to_copy)

    # If there is no input, we don't need to execute the model.
    if num_seq_groups == 0:
        return []

    output = self.model_runner.execute_model(seq_group_metadata_list,
                                             self.kv_cache)

    # OpenVINO worker only supports single-step execution.
    return [output]
def init_distributed_environment(self) -> None:
    """Initialize the distributed environment.

    Sets up the process group with the "gloo" backend using this worker's
    rank and init method, performs one small all-reduce as a warmup, and
    then makes sure the tensor/pipeline model-parallel groups exist.
    """
    cfg = self.parallel_config

    # The bare name below is the free function of the same name, not this
    # method (presumably imported at module level -- confirm).
    init_distributed_environment(
        world_size=cfg.world_size,
        rank=self.rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )

    # A small all_reduce for warmup.
    warmup_tensor = torch.zeros(1).cpu()
    torch.distributed.all_reduce(warmup_tensor)

    ensure_model_parallel_initialized(
        cfg.tensor_parallel_size,
        cfg.pipeline_parallel_size,
    )
def get_cache_block_size_bytes(self) -> int:
    """Return the size in bytes of a single KV cache block.

    The value is computed by ``OpenVINOCacheEngine.get_cache_block_size``
    from the cache block size and dtype plus the model and parallel
    configs.
    """
    cache_cfg = self.cache_config
    block_bytes = OpenVINOCacheEngine.get_cache_block_size(
        cache_cfg.block_size,
        cache_cfg.cache_dtype,
        self.model_config,
        self.parallel_config,
    )
    return block_bytes
def profile_run(self) -> int:
    """Profile device memory use with dummy inputs and return the KV budget.

    Runs one forward pass over dummy sequences sized from the scheduler
    limits, reads the device memory statistics from OpenVINO, and derives
    how much memory remains for the KV cache.

    Returns:
        ``total_device_memory * gpu_memory_utilization - used_device_mem``,
        where the used amount is inflated by a 10% safety margin.
        NOTE(review): the value is produced by float arithmetic although
        the annotation says ``int`` -- confirm what callers expect.

    Raises:
        RuntimeError: If the (inflated) memory needed to run the model is
            at least the total device memory.
    """
    ov_device = envs.VLLM_OPENVINO_DEVICE

    assert not current_platform.is_openvino_cpu(), \
        "CPU device isn't supposed to use profile run."

    import openvino.properties.device as device
    import openvino.properties.intel_gpu as intel_gpu

    ov_core = self.ov_core
    cache_config = self.cache_config
    model_config = self.model_config
    parallel_config = self.parallel_config
    device_config = self.device_config
    input_registry = INPUT_REGISTRY
    mm_registry = MULTIMODAL_REGISTRY
    mm_registry.init_mm_limits_per_prompt(model_config)

    # Execute a forward pass with dummy inputs to profile the memory usage
    # of the model.
    def model_profile_run():
        top_k = model_config.get_vocab_size() - 1
        sampling_params = SamplingParams(top_p=0.99, top_k=top_k)

        max_num_batched_tokens = \
            self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Temporary single-block cache used only for the profiling pass.
        tmp_cache_config = CacheConfig(cache_config.block_size,
                                       cache_config.gpu_memory_utilization,
                                       cache_config.swap_space_bytes,
                                       "auto")
        tmp_cache_config.num_gpu_blocks = 1
        tmp_cache_config.num_cpu_blocks = 0
        tmp_cache_config.cache_dtype = cache_config.cache_dtype

        profiling_cache_engine = OpenVINOCacheEngine(
            tmp_cache_config, model_config, parallel_config, device_config,
            ov_core, ov_device)

        # Profile memory usage with max_num_sequences sequences and the
        # total # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the remainder tokens over the first few groups.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            block_size = cache_config.block_size
            seq_num_blocks = (seq_len + block_size - 1) // block_size

            dummy_data = input_registry \
                .dummy_data_for_profiling(model_config,
                                          seq_len,
                                          mm_registry)

            block_tables = [[0] * seq_num_blocks] * max_num_seqs
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=block_tables,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data)
            seqs.append(seq)

        self.model_runner.block_size = tmp_cache_config.block_size

        bind_kv_cache(self.compilation_config.static_forward_context,
                      profiling_cache_engine.kv_cache)
        # Run the model with the dummy inputs.
        self.model_runner.execute_model(seqs,
                                        profiling_cache_engine.kv_cache)

        # Explicitly revert bind_kv_cache and delete temporary KV cache
        # manager to free KV cache when real inputs will be passed to OV
        bind_kv_cache(self.compilation_config.static_forward_context, [[
            torch.tensor([])
            for _ in range(len(profiling_cache_engine.kv_cache))
        ]])

        del profiling_cache_engine

    logger.info(
        "Start profiling run with dummy inputs to evaluate "
        "memory usage for %s. It might take a while.", ov_device)

    model_profile_run()

    gpu_device_type = ov_core.get_property(ov_device, device.type)
    memory_statistics = \
        ov_core.get_property(ov_device, intel_gpu.memory_statistics)
    memory_utilization = cache_config.gpu_memory_utilization

    if gpu_device_type == device.Type.INTEGRATED and \
            memory_utilization >= 0.9:
        logger.warning(
            "iGPU is used with high gpu_memory_utilization=%f "
            "value. This may cause low performance due to "
            "occupying the majority of available system "
            "memory. Please consider decreasing "
            "gpu_memory_utilization or explicitly setting "
            "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
            "variable.", memory_utilization)

    # sum up all used device memory
    device_memory_types = ["cl_mem", "usm_device"]
    used_device_mem = \
        sum(memory_statistics.get(key, 0) for key in device_memory_types)

    # Integrated devices additionally count the "usm_host" statistic.
    if gpu_device_type == device.Type.INTEGRATED:
        used_device_mem += memory_statistics.get("usm_host", 0)

    # there could be unaccounted extra memory reserved by kernels, kept
    # in memory pools, etc
    # therefore, add a threshold to account for this
    used_memory_threshold = 1.1
    used_device_mem *= used_memory_threshold

    total_device_memory = \
        ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)

    def format_memory_size(size) -> str:
        """Render ``size`` with two decimals in the largest fitting unit."""
        units = ["B", "KB", "MB", "GB"]
        unit_index = 0
        while size > 1024 and unit_index < len(units) - 1:
            size /= 1024
            unit_index += 1
        return f"{size:.2f} {units[unit_index]}"

    total_device_memory_str = format_memory_size(total_device_memory)
    used_device_memory_str = format_memory_size(used_device_mem)

    logger.info(
        "Total %s memory: %s. "
        "Amount of memory required to run the model with "
        "max_num_batched_tokens=%d: %s.", ov_device,
        total_device_memory_str,
        self.scheduler_config.max_num_batched_tokens,
        used_device_memory_str)

    if used_device_mem >= total_device_memory:
        # Every segment that interpolates a value needs the `f` prefix,
        # otherwise the braces are emitted literally.
        raise RuntimeError(
            f"The required memory size {used_device_memory_str} for model "
            "is higher than the total available device "
            f"memory {total_device_memory_str}. Please consider to "
            "decrease `max_num_batched_tokens` or increase "
            "`gpu_memory_utilization`")

    return total_device_memory * memory_utilization - used_device_mem
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment