Commit 7c4f76e3 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.0

parents 2da0dd3e 51c31bc1
"""Attention layer."""
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
class Attention(nn.Module):
    """Backend-agnostic attention layer.

    The layer accepts query, key, and value tensors that may hold either
    prompt tokens or generation tokens, and forwards them to the
    implementation class supplied by the selected attention backend.

    The delegated implementation is expected to:
    1. Write the incoming key and value tensors into the KV cache.
    2. Run (multi-head/multi-query/grouped-query) attention.
    3. Hand back the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Pick a backend for the default dtype and build its implementation.

        All arguments are passed positionally, in the order given here, to
        the implementation class returned by the backend.
        """
        super().__init__()
        # The backend is selected from torch's current default dtype.
        backend = get_attn_backend(torch.get_default_dtype())
        self.backend = backend
        implementation_cls = backend.get_impl_cls()
        self.impl = implementation_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate the attention computation to the backend implementation.

        Returns whatever the implementation's ``forward`` returns for the
        given tensors, cache, and metadata.
        """
        impl = self.impl
        return impl.forward(query, key, value, kv_cache, attn_metadata)
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm._C import cache_ops, ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Number of context tokens covered by one partition; `forward_decode` divides
# the maximum context length by this value (rounding up) to size the
# per-partition scratch buffers of the V2 kernel.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    Attributes:
        slot_mapping: Shape (num_tokens,). For every input token, the index
            of the cache slot it will be written to. Example: with a block
            size of 16, a `slot_mapping` of [35, 2, 17] places the three
            tokens at the 3rd slot in block 2, the 2nd slot in block 0, and
            the 1st slot in block 1, respectively.
        context_lens: Shape (batch_size,). Per-sequence context length, i.e.
            the number of tokens stored in the KV cache. WARNING: for a
            prefill request the new tokens are NOT counted; for decoding the
            new token IS counted.
        max_context_len: The largest context length within the batch.
        block_tables: Shape (batch_size, max_blocks_per_seq). Physical block
            addresses per sequence (seq id -> list of physical blocks).
            Example: [0, 1, 2] means the tokens live in blocks 0, 1 and 2 of
            the KV cache, each holding at most block_size tokens. When the
            batch is cuda-graph captured, the second dimension is padded up
            to max_blocks_per_seq.
        kv_cache_dtype: Name of the KV cache data type.
    """

    slot_mapping: torch.Tensor
    context_lens: Optional[torch.Tensor]
    max_context_len: Optional[int]
    block_tables: Optional[torch.Tensor]
    kv_cache_dtype: str
class PagedAttention:
    """Static helpers around the paged KV cache and its attention kernels.

    Every method is a ``staticmethod``. The class groups cache-layout
    helpers, cache maintenance calls into ``cache_ops``, and the attention
    kernel launches (``ops.paged_attention_v1/v2`` and
    ``context_attention_fwd``).
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return a new list with the supported head sizes."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined KV cache tensor.

        The leading dimension of size 2 separates keys (index 0) from values
        (index 1); each block is stored flattened.
        """
        elements_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elements_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into key and value views.

        Returns:
            ``(key_cache, value_cache)`` where the key view has shape
            ``(num_blocks, num_kv_heads, head_size // x, -1, x)`` and the
            value view has shape ``(num_blocks, num_kv_heads, head_size, -1)``.
        """
        block_count = kv_cache.shape[1]
        # x = how many cache elements fit into 16 bytes.
        pack_width = 16 // kv_cache.element_size()

        keys = kv_cache[0].view(block_count, num_kv_heads,
                                head_size // pack_width, -1, pack_width)
        values = kv_cache[1].view(block_count, num_kv_heads, head_size, -1)
        return keys, values

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        """Store ``key``/``value`` in the caches via ``reshape_and_cache``.

        ``slot_mapping`` is flattened before being handed to the cache op.
        """
        flat_slots = slot_mapping.flatten()
        cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                    flat_slots, kv_cache_dtype)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run paged attention and return a tensor shaped like ``query``.

        ``query`` must unpack to ``(num_seqs, num_heads, head_size)``; the
        block size is read from dimension 3 of ``value_cache``. Depending on
        a heuristic, either the V1 kernel or the partitioned V2 kernel is
        launched; both write their result into the returned tensor.
        """
        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape

        # Ceiling division of the longest context by the partition size.
        num_partitions = ((max_context_len + (_PARTITION_SIZE - 1)) //
                          _PARTITION_SIZE)

        # NOTE(woosuk): Simple heuristic for choosing between V1 and V2.
        # With a single partition V1 avoids the reduction overhead, and with
        # many sequences or heads V1 already has enough parallel work.
        # TODO(woosuk): Tune this heuristic.
        # Contexts longer than 8192 always go to V2 to avoid running short of
        # shared memory.
        context_fits_v1 = max_context_len <= 8192
        v1_is_worthwhile = (num_partitions == 1
                            or num_seqs * num_heads > 512)

        if context_fits_v1 and v1_is_worthwhile:
            # PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                context_lens,
                block_size,
                max_context_len,
                alibi_slopes,
                kv_cache_dtype,
            )
            return output

        # PagedAttention V2: a partition must hold a whole number of blocks.
        assert _PARTITION_SIZE % block_size == 0
        partition_stats_shape = (num_seqs, num_heads, num_partitions)
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=partition_stats_shape,
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``context_attention_fwd`` and return a tensor shaped like
        ``query``.

        The last entry of ``subquery_start_loc`` is dropped before the call.
        """
        output = torch.empty_like(query)
        # subquery_start_loc has batch_size + 1 entries; the kernel only
        # receives the first batch_size of them.
        start_locations = subquery_start_loc[:-1]
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            block_tables,
            start_locations,
            prompt_lens_tensor,
            context_lens,
            max_subquery_len,
            alibi_slopes,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Swap blocks between two KV caches according to ``src_to_dst``.

        The key half (index 0) is processed first, then the value half
        (index 1).
        """
        for half in (0, 1):
            cache_ops.swap_blocks(src_kv_cache[half], dst_kv_cache[half],
                                  src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy blocks inside every cache in ``kv_caches``.

        Key and value halves of all caches are gathered into two lists and
        passed to ``cache_ops.copy_blocks`` with the ``src_to_dists`` mapping.
        """
        key_caches: List[torch.Tensor] = []
        value_caches: List[torch.Tensor] = []
        for cache in kv_caches:
            key_caches.append(cache[0])
            value_caches.append(cache[1])
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
from functools import lru_cache
from typing import Type
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Return the attention backend class to use for ``dtype``.

    FlashAttention is chosen when `_can_use_flash_attn` accepts the dtype;
    otherwise the XFormers backend is returned. The backend module is
    imported lazily, and the result is memoized per dtype by ``lru_cache``.
    """
    if not _can_use_flash_attn(dtype):
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend

    logger.info("Using FlashAttention backend.")
    from vllm.attention.backends.flash_attn import (  # noqa: F401
        FlashAttentionBackend)
    return FlashAttentionBackend
def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    """Tell whether the FlashAttention backend is usable for ``dtype``.

    The checks run in order and stop at the first failure: AMD (HIP)
    platform, CUDA compute capability major version below 8, a dtype other
    than float16/bfloat16, and finally whether the ``flash_attn`` package
    can be imported. The reason for a rejection is logged at info level.
    """
    rejection: Optional[str] = None

    if is_hip():
        # AMD GPUs.
        rejection = "Cannot use FlashAttention backend for AMD GPUs."
    elif torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        rejection = ("Cannot use FlashAttention backend for Volta and Turing "
                     "GPUs.")
    elif dtype not in (torch.float16, torch.bfloat16):
        rejection = ("Cannot use FlashAttention backend for dtype other than "
                     "torch.float16 or torch.bfloat16.")
    else:
        # Only probe for the package once every hardware/dtype check passed.
        try:
            import flash_attn  # noqa: F401
        except ImportError:
            rejection = (
                "Cannot use FlashAttention because the package is not found. "
                "Please install it for better performance.")

    if rejection is None:
        return True
    logger.info(rejection)
    return False
...@@ -5,6 +5,8 @@ from vllm.utils import Device ...@@ -5,6 +5,8 @@ from vllm.utils import Device
_BLANK_TOKEN_ID = -1 _BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock: class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right. """A block that stores a contiguous chunk of tokens from left to right.
...@@ -55,17 +57,27 @@ class PhysicalTokenBlock: ...@@ -55,17 +57,27 @@ class PhysicalTokenBlock:
device: Device, device: Device,
block_number: int, block_number: int,
block_size: int, block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None: ) -> None:
self.device = device self.device = device
self.block_number = block_number self.block_number = block_number
self.block_size = block_size self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens
self.ref_count = 0 self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
self.computed = False
def __repr__(self) -> str: def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, ' return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, ' f'block_number={self.block_number}, '
f'ref_count={self.ref_count})') f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')
# Mapping: logical block number -> physical block. # Mapping: logical block number -> physical block.
......
from typing import Optional, Union, ClassVar import enum
from dataclasses import dataclass import json
import os import os
from packaging.version import Version from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Union
import torch import torch
from packaging.version import Version
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -45,7 +50,7 @@ class ModelConfig: ...@@ -45,7 +50,7 @@ class ModelConfig:
a tag name, or a commit id. If unspecified, will use the default a tag name, or a commit id. If unspecified, will use the default
version. version.
code_revision: The specific revision to use for the model code on code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version. commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use branch name, a tag name, or a commit id. If unspecified, will use
...@@ -79,6 +84,7 @@ class ModelConfig: ...@@ -79,6 +84,7 @@ class ModelConfig:
quantization: Optional[str] = None, quantization: Optional[str] = None,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -93,11 +99,14 @@ class ModelConfig: ...@@ -93,11 +99,14 @@ class ModelConfig:
self.quantization = quantization self.quantization = quantization
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C # pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
if not os.path.exists(model): if not os.path.exists(model):
model_path = snapshot_download(model_id=model, model_path = snapshot_download(model_id=model,
cache_dir=download_dir, cache_dir=download_dir,
...@@ -110,8 +119,9 @@ class ModelConfig: ...@@ -110,8 +119,9 @@ class ModelConfig:
self.hf_config = get_config(self.model, trust_remote_code, revision, self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision) code_revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.hf_text_config = get_hf_text_config(self.hf_config)
self.max_model_len = _get_and_verify_max_len(self.hf_config, self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len) max_model_len)
self._verify_load_format() self._verify_load_format()
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
...@@ -134,7 +144,7 @@ class ModelConfig: ...@@ -134,7 +144,7 @@ class ModelConfig:
if (f not in rocm_not_supported_load_format) if (f not in rocm_not_supported_load_format)
] ]
raise ValueError( raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. " f"load format '{load_format}' is not supported in ROCm. "
f"Supported load format are " f"Supported load format are "
f"{rocm_supported_load_format}") f"{rocm_supported_load_format}")
...@@ -163,13 +173,18 @@ class ModelConfig: ...@@ -163,13 +173,18 @@ class ModelConfig:
# Parse quantization method from the HF model config, if available. # Parse quantization method from the HF model config, if available.
hf_quant_config = getattr(self.hf_config, "quantization_config", None) hf_quant_config = getattr(self.hf_config, "quantization_config", None)
if hf_quant_config is not None: if hf_quant_config is not None:
hf_quant_method = str(hf_quant_config["quant_method"]).lower() hf_quant_method = str(hf_quant_config["quant_method"]).lower()
# If the GPTQ model is serialized in marlin format, use marlin. # If the GPTQ model is serialized in marlin format, use marlin.
if (hf_quant_method == "gptq" if (hf_quant_method == "gptq"
and "is_marlin_format" in hf_quant_config and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]): and hf_quant_config["is_marlin_format"]):
logger.info("The model is serialized in Marlin format. "
"Using Marlin kernel.")
hf_quant_method = "marlin" hf_quant_method = "marlin"
if self.quantization == "gptq":
self.quantization = hf_quant_method
if self.quantization is None: if self.quantization is None:
self.quantization = hf_quant_method self.quantization = hf_quant_method
elif self.quantization != hf_quant_method: elif self.quantization != hf_quant_method:
...@@ -187,8 +202,8 @@ class ModelConfig: ...@@ -187,8 +202,8 @@ class ModelConfig:
if is_hip( if is_hip(
) and self.quantization in rocm_not_supported_quantization: ) and self.quantization in rocm_not_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not supported " f"{self.quantization} quantization is currently not "
f"in ROCm.") f"supported in ROCm.")
if self.quantization != "marlin": if self.quantization != "marlin":
logger.warning( logger.warning(
f"{self.quantization} quantization is not fully " f"{self.quantization} quantization is not fully "
...@@ -205,7 +220,7 @@ class ModelConfig: ...@@ -205,7 +220,7 @@ class ModelConfig:
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
) -> None: ) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads total_num_attention_heads = self.hf_text_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0: if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError( raise ValueError(
...@@ -213,7 +228,7 @@ class ModelConfig: ...@@ -213,7 +228,7 @@ class ModelConfig:
" must be divisible by tensor parallel size " " must be divisible by tensor parallel size "
f"({tensor_parallel_size}).") f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_text_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0: if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError( raise ValueError(
...@@ -222,19 +237,29 @@ class ModelConfig: ...@@ -222,19 +237,29 @@ class ModelConfig:
f"({pipeline_parallel_size}).") f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]: def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None) """Get the sliding window size, or None if disabled.
"""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
return getattr(self.hf_text_config, "sliding_window", None)
def get_vocab_size(self) -> int: def get_vocab_size(self) -> int:
return self.hf_config.vocab_size return self.hf_text_config.vocab_size
def get_hidden_size(self) -> int: def get_hidden_size(self) -> int:
return self.hf_config.hidden_size return self.hf_text_config.hidden_size
def get_head_size(self) -> int: def get_head_size(self) -> int:
if hasattr(self.hf_config, "head_dim"): if hasattr(self.hf_text_config, "head_dim"):
return self.hf_config.head_dim return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return (self.hf_text_config.hidden_size //
self.hf_text_config.num_attention_heads)
def get_total_num_kv_heads(self) -> int: def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads.""" """Returns the total number of KV heads."""
...@@ -246,12 +271,17 @@ class ModelConfig: ...@@ -246,12 +271,17 @@ class ModelConfig:
new_decoder_arch_falcon = ( new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False)) and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config, if not new_decoder_arch_falcon and getattr(self.hf_text_config,
"multi_query", False): "multi_query", False):
# Multi-query attention, only one KV head. # Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case. # Currently, tensor parallelism is not supported in this case.
return 1 return 1
# For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
attributes = [ attributes = [
# For Falcon: # For Falcon:
"n_head_kv", "n_head_kv",
...@@ -262,13 +292,13 @@ class ModelConfig: ...@@ -262,13 +292,13 @@ class ModelConfig:
"multi_query_group_num", "multi_query_group_num",
] ]
for attr in attributes: for attr in attributes:
num_kv_heads = getattr(self.hf_config, attr, None) num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None: if num_kv_heads is not None:
return num_kv_heads return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is # For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads. # equal to the number of attention heads.
return self.hf_config.num_attention_heads return self.hf_text_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU.""" """Returns the number of KV heads per GPU."""
...@@ -281,7 +311,7 @@ class ModelConfig: ...@@ -281,7 +311,7 @@ class ModelConfig:
total_num_kv_heads // parallel_config.tensor_parallel_size) total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size return total_num_hidden_layers // parallel_config.pipeline_parallel_size
...@@ -294,6 +324,8 @@ class CacheConfig: ...@@ -294,6 +324,8 @@ class CacheConfig:
vLLM execution. vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB). swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage. cache_dtype: Data type for kv cache storage.
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
""" """
def __init__( def __init__(
...@@ -302,13 +334,17 @@ class CacheConfig: ...@@ -302,13 +334,17 @@ class CacheConfig:
gpu_memory_utilization: float, gpu_memory_utilization: float,
swap_space: int, swap_space: int,
cache_dtype: str, cache_dtype: str,
forced_num_gpu_blocks: Optional[int] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB self.swap_space_bytes = swap_space * _GB
self.forced_num_gpu_blocks = forced_num_gpu_blocks
self.cache_dtype = cache_dtype self.cache_dtype = cache_dtype
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args() self._verify_args()
self._verify_cache_dtype() self._verify_cache_dtype()
...@@ -317,7 +353,8 @@ class CacheConfig: ...@@ -317,7 +353,8 @@ class CacheConfig:
self.num_cpu_blocks = None self.num_cpu_blocks = None
def metrics_info(self): def metrics_info(self):
# convert cache_config to dict(key: str, value:str) for prometheus metrics info # convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()} return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -330,15 +367,14 @@ class CacheConfig: ...@@ -330,15 +367,14 @@ class CacheConfig:
if self.cache_dtype == "auto": if self.cache_dtype == "auto":
pass pass
elif self.cache_dtype == "fp8_e5m2": elif self.cache_dtype == "fp8_e5m2":
if is_hip():
raise NotImplementedError(
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
nvcc_cuda_version = get_nvcc_cuda_version() nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"): if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
raise ValueError( raise ValueError(
"FP8 is not supported when cuda version is lower than 11.8." "FP8 is not supported when cuda version is lower than 11.8."
) )
device_name = torch.cuda.get_device_name()
if "AMD" in device_name:
raise NotImplementedError(
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
logger.info( logger.info(
"Using fp8_e5m2 data type to store kv cache. It reduces " "Using fp8_e5m2 data type to store kv cache. It reduces "
"the GPU memory footprint and boosts the performance. " "the GPU memory footprint and boosts the performance. "
...@@ -367,6 +403,58 @@ class CacheConfig: ...@@ -367,6 +403,58 @@ class CacheConfig:
logger.warning("Possibly too large swap space. " + msg) logger.warning("Possibly too large swap space. " + msg)
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
"""
pool_size: int
pool_type: str
extra_config: dict
def __post_init__(self):
if self.pool_type not in ("ray", ):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@classmethod
def create_config(
cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
class ParallelConfig: class ParallelConfig:
"""Configuration for the distributed execution. """Configuration for the distributed execution.
...@@ -381,6 +469,10 @@ class ParallelConfig: ...@@ -381,6 +469,10 @@ class ParallelConfig:
parallel and large models. parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL. fall back to NCCL.
tokenizer_pool_config: Config for the tokenizer pool.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
""" """
def __init__( def __init__(
...@@ -390,23 +482,21 @@ class ParallelConfig: ...@@ -390,23 +482,21 @@ class ParallelConfig:
worker_use_ray: bool, worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None, max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
) -> None: ) -> None:
self.pipeline_parallel_size = pipeline_parallel_size self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron(): self.tensor_parallel_size = tensor_parallel_size
# For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly.
# Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend. if self.world_size > 1:
if self.world_size > 1 and not is_neuron():
self.worker_use_ray = True self.worker_use_ray = True
self._verify_args() self._verify_args()
...@@ -425,15 +515,9 @@ class ParallelConfig: ...@@ -425,15 +515,9 @@ class ParallelConfig:
logger.info( logger.info(
"Disabled the custom all-reduce kernel because it is not " "Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.") "supported with pipeline parallelism.")
if self.ray_workers_use_nsight and not self.worker_use_ray:
# FIXME(woosuk): Fix the stability issues and re-enable the custom raise ValueError("Unable to use nsight profiling unless workers "
# all-reduce kernel. "run with Ray.")
if not self.disable_custom_all_reduce and self.world_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Custom all-reduce kernels are temporarily disabled due to "
"stability issues. We will re-enable them once the issues are "
"resolved.")
class SchedulerConfig: class SchedulerConfig:
...@@ -446,7 +530,11 @@ class SchedulerConfig: ...@@ -446,7 +530,11 @@ class SchedulerConfig:
iteration. iteration.
max_model_len: Maximum length of a sequence (including prompt max_model_len: Maximum length of a sequence (including prompt
and generated text). and generated text).
max_paddings: Maximum number of paddings to be added to a batch. delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
""" """
def __init__( def __init__(
...@@ -454,7 +542,9 @@ class SchedulerConfig: ...@@ -454,7 +542,9 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int], max_num_batched_tokens: Optional[int],
max_num_seqs: int, max_num_seqs: int,
max_model_len: int, max_model_len: int,
max_paddings: int, use_v2_block_manager: bool = False,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
) -> None: ) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
...@@ -464,7 +554,9 @@ class SchedulerConfig: ...@@ -464,7 +554,9 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_paddings = max_paddings self.delay_factor = delay_factor
self.use_v2_block_manager = use_v2_block_manager
self.chunked_prefill_enabled = enable_chunked_prefill
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -488,12 +580,12 @@ class DeviceConfig: ...@@ -488,12 +580,12 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":
# Automated device type detection # Automated device type detection
if torch.cuda.is_available(): if is_neuron():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
else: else:
raise RuntimeError("No supported device detected.") # We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
else: else:
# Device type is assigned explicitly # Device type is assigned explicitly
self.device_type = device self.device_type = device
...@@ -505,10 +597,6 @@ class DeviceConfig: ...@@ -505,10 +597,6 @@ class DeviceConfig:
# Set device with device type # Set device with device type
self.device = torch.device(self.device_type) self.device = torch.device(self.device_type)
@property
def is_neuron(self):
return self.device_type == "neuron"
@dataclass @dataclass
class LoRAConfig: class LoRAConfig:
...@@ -558,6 +646,48 @@ class LoRAConfig: ...@@ -558,6 +646,48 @@ class LoRAConfig:
"LoRA is enabled.") "LoRA is enabled.")
@dataclass
class VisionLanguageConfig:
    """Settings describing the image input of a vision language model.

    Holds the kind of image input the model consumes together with the
    token id and sizes associated with that input.
    """

    class ImageInputType(enum.Enum):
        """Kinds of image input a vision language model can be given.

        An image roughly goes through the following transformation:
        Raw image --> pixel values --> image features --> image embeddings.

        The input types differ in where the image encoder
        (pixel values --> image features) is run, and they also correspond
        to different tensor shapes.
        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
        IMAGE_FEATURES: (1, 576, 1024).
        """
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()

    # Which of the ImageInputType members the model is fed.
    image_input_type: ImageInputType
    # The input id corresponding to image token.
    image_token_id: int
    # Used for running `run_prefill_max_token`.
    # For models that support varying resolution, this corresponds to
    # worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    image_feature_size: int

    @classmethod
    def get_image_input_enum_type(
            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
        """Look up an ImageInputType member by case-insensitive name.

        Args:
            value (str): Name of the member, in any letter case.

        Returns:
            VisionLanguageConfig.ImageInputType: The member whose name is
                `value.upper()`.

        Raises:
            ValueError: If no member has that name. The originating
                KeyError is attached as the cause.
        """
        member_name = value.upper()
        try:
            return cls.ImageInputType[member_name]
        except KeyError as e:
            raise ValueError(f"{value} is not a valid choice. "
                             f"Expecting to choose from "
                             f"{[x.name for x in cls.ImageInputType]}.") from e
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,
...@@ -602,7 +732,7 @@ def _get_and_verify_dtype( ...@@ -602,7 +732,7 @@ def _get_and_verify_dtype(
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE) if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
] ]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}") f"Supported dtypes are {rocm_supported_dtypes}")
# Verify the dtype. # Verify the dtype.
...@@ -635,15 +765,20 @@ def _get_and_verify_max_len( ...@@ -635,15 +765,20 @@ def _get_and_verify_max_len(
"max_seq_len", "max_seq_len",
# ChatGLM2 # ChatGLM2
"seq_length", "seq_length",
# Command-R
"model_max_length",
# Others # Others
"max_sequence_length", "max_sequence_length",
"max_seq_length", "max_seq_length",
"seq_len", "seq_len",
] ]
max_len_key = None
for key in possible_keys: for key in possible_keys:
max_len_key = getattr(hf_config, key, None) max_len = getattr(hf_config, key, None)
if max_len_key is not None: if max_len is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key) max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)
if derived_max_model_len == float("inf"): if derived_max_model_len == float("inf"):
if max_model_len is not None: if max_model_len is not None:
# If max_model_len is specified, we use it. # If max_model_len is specified, we use it.
...@@ -669,10 +804,18 @@ def _get_and_verify_max_len( ...@@ -669,10 +804,18 @@ def _get_and_verify_max_len(
if max_model_len is None: if max_model_len is None:
max_model_len = derived_max_model_len max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len: elif max_model_len > derived_max_model_len:
raise ValueError( # Some models might have a separate key for specifying model_max_length
f"User-specified max_model_len ({max_model_len}) is greater than " # that will be bigger than derived_max_model_len. We compare user input
f"the derived max_model_len ({max_len_key}={derived_max_model_len}" # with model_max_length and allow this override when it's smaller.
" in model's config.json). This may lead to incorrect model " model_max_length = getattr(hf_config, "model_max_length", None)
"outputs or CUDA errors. Make sure the value is correct and " if model_max_length is not None and max_model_len <= model_max_length:
"within the model context size.") pass
else:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater "
"than the derived max_model_len "
f"({max_len_key}={derived_max_model_len} or model_max_length="
f"{model_max_length} in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size.")
return int(max_model_len) return int(max_model_len)
from typing import List, Optional
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
class BlockTable:
    """Tracks the blocks that hold the tokens of one sequence.

    A BlockTable owns an ordered list of blocks. Each block stores up to
    `block_size` token ids, and all blocks are obtained from and returned to
    a DeviceAwareBlockAllocator.

    Args:
        block_size (int): The maximum number of tokens that can be stored in a
            single block.
        block_allocator (DeviceAwareBlockAllocator): The allocator that
            creates, forks and frees the blocks.
        _blocks (Optional[List[Block]], optional): Existing blocks to wrap.
            When omitted, the table starts out unallocated.

    Attributes:
        _block_size (int): The maximum number of tokens per block.
        _allocator (DeviceAwareBlockAllocator): The allocator used for every
            block operation.
        _blocks (Optional[List[Block]]): The blocks of this table, or None
            while the table is unallocated.
        _num_full_slots (int): The number of tokens currently stored in the
            blocks.
    """

    def __init__(
        self,
        block_size: int,
        block_allocator: DeviceAwareBlockAllocator,
        _blocks: Optional[List[Block]] = None,
    ):
        self._block_size = block_size
        self._allocator = block_allocator
        self._blocks: Optional[List[Block]] = _blocks

        # Count tokens through the helper rather than from `_blocks`
        # directly, because the table may not be allocated yet.
        self._num_full_slots = len(self._get_all_token_ids())

    @staticmethod
    def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
        """Returns how many blocks are needed to hold `token_ids`.

        This is the worst case, where every block needs a fresh allocation
        (e.g. ignoring prefix caching).

        Args:
            token_ids (List[int]): The sequence of token IDs to be stored.
            block_size (int): The maximum number of tokens per block.

        Returns:
            int: `len(token_ids)` divided by `block_size`, computed with
                `cdiv`.
        """
        num_tokens = len(token_ids)
        return cdiv(num_tokens, block_size)

    def allocate(self,
                 token_ids: List[int],
                 device: Device = Device.GPU) -> None:
        """Allocates the blocks that store `token_ids`.

        The table must not be allocated yet and `token_ids` must be
        non-empty; both conditions are asserted.

        Args:
            token_ids (List[int]): The sequence of token IDs to be stored.
            device (Device, optional): The device on which the blocks should
                be allocated. Defaults to Device.GPU.
        """
        assert not self._is_allocated
        assert token_ids
        new_blocks = self._allocate_blocks_for_token_ids(prev_block=None,
                                                         token_ids=token_ids,
                                                         device=device)
        self._blocks = new_blocks
        self._num_full_slots = len(token_ids)

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends `token_ids` after the tokens already in the table.

        Room is first guaranteed through `ensure_num_empty_slots`. The token
        ids are then split into one leading chunk that fits the remaining
        space of the first non-full block, followed by chunks of
        `block_size`, and each chunk is appended to its own block.

        Args:
            token_ids (List[int]): The sequence of token IDs to be appended.
        """
        assert self._is_allocated

        self.ensure_num_empty_slots(num_empty_slots=len(token_ids))

        # Index of the first block with free room, and how many of its slots
        # are already taken.
        first_block_idx, used_in_first = divmod(self._num_full_slots,
                                                self._block_size)
        target_blocks = self._blocks[first_block_idx:]
        room_in_first = self._block_size - used_in_first

        leading_chunk = token_ids[:room_in_first]
        remaining_ids = token_ids[room_in_first:]
        chunks = [leading_chunk] + chunk_list(remaining_ids,
                                              self._block_size)

        for target, chunk in zip(target_blocks, chunks):
            target.append_token_ids(chunk)

        self._num_full_slots += len(token_ids)

    def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
        """Grows the table until it has at least `num_empty_slots` free slots.

        Nothing happens when enough slots are already free. Otherwise new
        mutable blocks are allocated on the GPU and chained after the current
        last block.

        Args:
            num_empty_slots (int): The minimum number of empty slots required.
        """
        assert self._is_allocated

        missing_slots = num_empty_slots - self._num_empty_slots
        if missing_slots <= 0:
            return

        for _ in range(cdiv(missing_slots, self._block_size)):
            tail = self._blocks[-1]
            # Currently the block table only supports
            # appending tokens to GPU blocks.
            self._blocks.append(
                self._allocator.allocate_mutable(prev_block=tail,
                                                 device=Device.GPU))

    def fork(self) -> "BlockTable":
        """Creates a new BlockTable from this table's blocks.

        The allocator's `fork` is called with the last block of this table,
        and the blocks it returns back a new BlockTable that uses the same
        block size and allocator.

        Returns:
            BlockTable: A new BlockTable wrapping the forked blocks.
        """
        assert self._is_allocated
        last_block = self._blocks[-1]
        return BlockTable(
            block_size=self._block_size,
            block_allocator=self._allocator,
            _blocks=self._allocator.fork(last_block),
        )

    def free(self) -> None:
        """Releases every block of the table.

        Each block is handed to the allocator's `free`; afterwards `_blocks`
        is set to `None`, which marks the table as unallocated.
        """
        assert self._is_allocated
        release = self._allocator.free
        for block in self._blocks:
            release(block)
        self._blocks = None

    @property
    def physical_block_ids(self) -> List[int]:
        """Returns the `block_id` of every block, in table order.

        Returns:
            List[int]: One physical block index per block in `_blocks`.
        """
        assert self._is_allocated
        block_ids = []
        for block in self._blocks:
            block_ids.append(block.block_id)
        return block_ids

    def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
                                       token_ids: List[int],
                                       device: Device) -> List[Block]:
        """Allocates one block per `block_size` chunk of `token_ids`.

        Each new block is chained to the one allocated before it, starting
        from `prev_block`.
        """
        allocated = []
        tail = prev_block
        for chunk in chunk_list(token_ids, self._block_size):
            if len(chunk) != self._block_size:
                # Partially fill a mutable block with token ids.
                tail = self._allocator.allocate_mutable(prev_block=tail,
                                                        device=device)
                tail.append_token_ids(chunk)
            else:
                # The chunk fills a whole block: create an immutable block.
                tail = self._allocator.allocate_immutable(tail,
                                                          token_ids=chunk,
                                                          device=device)
            allocated.append(tail)
        return allocated

    def _get_all_token_ids(self) -> List[int]:
        """Returns the token ids of all blocks as one flat list."""
        # NOTE: This function is O(seq_len); use sparingly.
        if not self._is_allocated:
            return []
        return [
            token_id for block in self._blocks
            for token_id in block.token_ids
        ]

    @property
    def _is_allocated(self) -> bool:
        """Whether the table currently holds a block list."""
        return self._blocks is not None

    @property
    def _num_empty_slots(self) -> int:
        """Number of slots in the table that hold no token yet."""
        assert self._is_allocated
        capacity = len(self._blocks) * self._block_size
        return capacity - self._num_full_slots

    @property
    def num_full_slots(self) -> int:
        """Returns the total number of tokens currently stored in the
        BlockTable.

        Returns:
            int: The total number of tokens currently stored in the BlockTable.
        """
        return self._num_full_slots
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
RefCount = int
class RefCounter:
    """Keeps a reference count for every block index it was created with.

    Every index starts with a count of zero. Counts are changed with `incr`
    and `decr` and read with `get`; all three assert that the index is one
    of the indices supplied at construction time.

    Args:
        all_block_indices (Iterable[BlockId]): The block indices to track.
            Duplicate indices are collapsed.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        unique_indices = set(all_block_indices)
        self._refcounts: Dict[BlockId,
                              RefCount] = dict.fromkeys(unique_indices, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Adds one reference to `block_id` and returns the new count."""
        counts = self._refcounts
        assert block_id in counts
        assert counts[block_id] >= 0

        counts[block_id] += 1
        return counts[block_id]

    def decr(self, block_id: BlockId) -> RefCount:
        """Drops one reference from `block_id` and returns the new count.

        The count must be positive before the call.
        """
        counts = self._refcounts
        assert block_id in counts
        assert counts[block_id] > 0

        counts[block_id] -= 1
        return counts[block_id]

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the current reference count of `block_id`."""
        counts = self._refcounts
        assert block_id in counts
        return counts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Returns a ReadOnlyRefCounter wrapping this counter."""
        return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
    """A wrapper around a RefCounter that only permits reads.

    `get` is forwarded to the wrapped RefCounter, while `incr` and `decr`
    always raise.

    Args:
        refcounter (RefCounter): The RefCounter whose counts are exposed.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; counts cannot be changed here."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; counts cannot be changed here."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Returns the wrapped counter's reference count for `block_id`."""
        wrapped = self._refcounter
        return wrapped.get(block_id)
class CopyOnWriteTracker:
    """Records copy-on-write operations performed on blocks.

    The tracker keeps a mapping from a source block index to the list of
    destination block indices that were allocated for it. Reference counts
    are read from a RefCounter, and blocks are freed and allocated through a
    BlockAllocator.

    Args:
        refcounter (RefCounter): The reference counter consulted for block
            reference counts.
        allocator (BlockAllocator): The block allocator used to free and
            allocate blocks.
    """

    def __init__(
        self,
        refcounter: RefCounter,
        allocator: BlockAllocator,
    ):
        self._copy_on_writes = defaultdict(list)
        self._refcounter = refcounter
        self._allocator = allocator

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Replaces the block's allocation when the block is shared.

        A block whose reference count is greater than 1 is passed to the
        allocator's `free`, a new mutable block with the same `prev_block`
        is allocated, and the source/destination pair is recorded. A block
        that has no block id, or whose reference count is 1, is left
        untouched.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the newly allocated block
                if a copy-on-write was performed, otherwise the block's
                current block index (which may be None).
        """
        current_id = block.block_id
        if current_id is None:
            return current_id

        refcount = self._refcounter.get(current_id)
        assert refcount != 0
        if refcount <= 1:
            # Not shared, so no copy is needed.
            return current_id

        # Give up this block's reference to the old allocation first ...
        self._allocator.free(block)

        # ... then obtain a fresh block chained to the same predecessor.
        replacement = self._allocator.allocate_mutable(
            prev_block=block.prev_block)
        new_id = replacement.block_id

        # Remember the src/dst pair.
        self._copy_on_writes[current_id].append(new_id)
        return new_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the recorded copy-on-writes and resets the record.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices for the
                copy-on-write operations recorded since the last call.
        """
        recorded = dict(self._copy_on_writes)
        self._copy_on_writes.clear()
        return recorded
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain of `prev_block` links is followed from `last_block` back to
    the first block (the one whose `prev_block` is None). The walk is done
    with a loop rather than with recursion, so the number of blocks in a
    sequence is not limited by the interpreter's recursion depth.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order
            they appear (first block first, `last_block` last).
    """
    # Collect from the tail towards the head, then flip once at the end.
    all_blocks = [last_block]
    block = last_block.prev_block
    while block is not None:
        all_blocks.append(block)
        block = block.prev_block

    all_blocks.reverse()
    return all_blocks
from typing import Dict, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.utils import Device
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
    """A block allocator that can allocate blocks on both CPU and GPU memory.

    This class implements the `DeviceAwareBlockAllocator` interface by
    holding one `BlockAllocator` per device and forwarding each call to the
    allocator responsible for the requested device or for the given block.
    """

    @staticmethod
    def create(
        allocator_type: str,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
    ) -> DeviceAwareBlockAllocator:
        """Creates a CpuGpuBlockAllocator instance with the specified
        configuration.

        Args:
            allocator_type (str): The type of block allocator to use for CPU
                and GPU blocks. Currently supported values are "naive" and
                "prefix_caching".
            num_gpu_blocks (int): The number of blocks to allocate for GPU
                memory.
            num_cpu_blocks (int): The number of blocks to allocate for CPU
                memory.
            block_size (int): The size of each block in number of tokens.

        Returns:
            DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
                specified configuration.

        Raises:
            ValueError: If `allocator_type` is not one of the supported
                values.

        Notes:
            - The block IDs are assigned contiguously, with GPU block IDs coming
                before CPU block IDs.
        """
        block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
        gpu_block_ids = block_ids[:num_gpu_blocks]
        cpu_block_ids = block_ids[num_gpu_blocks:]

        if allocator_type == "naive":
            gpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        elif allocator_type == "prefix_caching":
            gpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        else:
            raise ValueError(f"Unknown allocator type {allocator_type=}")

        return CpuGpuBlockAllocator(
            cpu_block_allocator=cpu_allocator,
            gpu_block_allocator=gpu_allocator,
        )

    def __init__(
        self,
        cpu_block_allocator: BlockAllocator,
        gpu_block_allocator: BlockAllocator,
    ):
        """Wraps one allocator per device.

        Args:
            cpu_block_allocator (BlockAllocator): Allocator for CPU blocks.
            gpu_block_allocator (BlockAllocator): Allocator for GPU blocks.
                Its block ids must be disjoint from the CPU allocator's.
        """
        assert not (
            cpu_block_allocator.all_block_ids
            & gpu_block_allocator.all_block_ids
        ), "cpu and gpu block allocators can't have intersection of block ids"

        self._allocators = {
            Device.CPU: cpu_block_allocator,
            Device.GPU: gpu_block_allocator,
        }

        # Reverse lookup so that free()/fork() can find the allocator that
        # owns a block from the block id alone.
        self._block_ids_to_allocator = {}
        for allocator in self._allocators.values():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated mutable block.
        """
        return self._allocators[device].allocate_mutable(prev_block)

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocates a new immutable block with the provided token IDs on the
        specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            token_ids (List[int]): The list of token IDs to be stored in the new
                block.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated immutable block containing the provided
                token IDs.
        """
        return self._allocators[device].allocate_immutable(
            prev_block, token_ids)

    def free(self, block: Block) -> None:
        """Frees the memory occupied by the given block.

        The call is forwarded to the allocator that owns the block's id.

        Args:
            block (Block): The block to be freed.
        """
        allocator = self._block_ids_to_allocator[block.block_id]
        return allocator.free(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
            memory as the original sequence.

        The call is forwarded to the allocator that owns `last_block`'s id.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: A new list of blocks that shares the same memory as the
                original sequence.
        """
        allocator = self._block_ids_to_allocator[last_block.block_id]
        return allocator.fork(last_block)

    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks available on the specified device.

        Args:
            device (Device): The device for which to query the number of free
                blocks.

        Returns:
            int: The number of free blocks available on the specified device.
        """
        return self._allocators[device].get_num_free_blocks()

    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Clears the copy-on-write (CoW) state and returns the mapping of
            source to destination block IDs.

        Returns:
            Dict[int, List[int]]: A dictionary mapping source block IDs to lists
                of destination block IDs.
        """
        # CoW only supported on GPU
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

    def mark_blocks_as_computed(self) -> None:
        """Forwards `mark_blocks_as_computed` to the GPU allocator."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed()

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Forwards `get_common_computed_block_ids` to the GPU allocator.

        Args:
            seq_block_ids (List[List[int]]): One list of block ids per
                sequence.

        Returns:
            List[int]: Whatever the GPU allocator returns for these
                sequences.
        """
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_common_computed_block_ids(
            seq_block_ids)

    # Exposed as a property to match the BlockAllocator interface, which
    # declares `all_block_ids` as an abstract property and is read as an
    # attribute (see __init__ above). The annotation is quoted so that the
    # builtin-generic subscript is not evaluated at class-definition time.
    @property
    def all_block_ids(self) -> "frozenset[int]":
        """Returns the ids of all blocks managed by either allocator."""
        return frozenset(self._block_ids_to_allocator.keys())
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol
from vllm.utils import Device
class Block(ABC):
    """Abstract interface of a block that stores token ids.

    Concrete blocks expose their token ids, their physical block id, how
    much room is left, and a link to the previous block of the sequence.

    The abstract properties are declared with `@property` stacked on
    `@abstractmethod`; `abc.abstractproperty` is deprecated.
    """

    @abstractmethod
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends `token_ids` to the block."""
        pass

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """The physical block index, or None."""
        pass

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        """The token ids stored in the block."""
        pass

    @property
    @abstractmethod
    def num_empty_slots(self) -> int:
        """The number of token slots that are still free."""
        pass

    @property
    @abstractmethod
    def is_full(self) -> bool:
        """Whether the block has no free slot left."""
        pass

    @property
    @abstractmethod
    def prev_block(self) -> Optional["Block"]:
        """The previous block in the sequence, or None."""
        pass

    class Factory(Protocol):
        """Call signature of a callable that constructs a Block."""

        @abstractmethod
        def __call__(
            self,
            prev_block: Optional["Block"],
            token_ids: List[int],
            block_size: int,
            allocator: "BlockAllocator",
            block_id: Optional[int] = None,
        ) -> "Block":
            pass
class BlockAllocator(ABC):
    """Abstract interface of an allocator that hands out and frees blocks.

    `all_block_ids` is declared with `@property` stacked on
    `@abstractmethod`; `abc.abstractproperty` is deprecated.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a new mutable block that follows `prev_block`."""
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates a new block holding `token_ids` that follows
        `prev_block`."""
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Releases `block`."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Returns a new list of blocks for the sequence ending in
        `last_block`."""
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Returns the number of blocks that are currently free."""
        pass

    # The annotation is quoted so that the builtin-generic subscript is not
    # evaluated at class-definition time.
    @property
    @abstractmethod
    def all_block_ids(self) -> "frozenset[int]":
        """The ids of all blocks managed by this allocator."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Returns the source -> destinations copy-on-write mapping and
        clears it."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self) -> None:
        """Marks blocks as computed."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Returns the computed block ids shared by the given sequences."""
        pass

    class NoFreeBlocksError(ValueError):
        """Raised by an allocator that has no free block left to hand out."""
        pass
class DeviceAwareBlockAllocator(BlockAllocator):
    """A BlockAllocator whose allocation and free-count calls take a device.

    The three methods below redeclare their BlockAllocator counterparts
    with an additional `device` parameter.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on `device` that follows
        `prev_block`."""

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocates a new block on `device` that holds `token_ids` and
        follows `prev_block`."""

    @abstractmethod
    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks on `device`."""
from typing import Dict, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
Refcount = int
class NaiveBlockAllocator(BlockAllocator):
    """A simple block allocator that manages blocks of memory without prefix
    caching.

    Args:
        create_block (Block.Factory): A factory function for creating new
            blocks. This is used when a NaiveBlockAllocator is composed within
            a prefix caching allocator -- the naive block allocator must
            construct prefix caching blocks (but shouldn't know anything else
            about them).
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1. The iterable is consumed exactly once,
            so one-shot iterators are accepted.
    """

    def __init__(
        self,
        create_block: Block.Factory,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        if block_ids is None:
            block_ids = range(num_blocks)

        # `block_ids` is only required to be an Iterable, so it may be a
        # one-shot iterator. Materialize it once; building both sets straight
        # from it would leave the second one empty.
        block_ids = tuple(block_ids)

        self._free_block_indices: Set[BlockId] = set(block_ids)
        self._all_block_indices = frozenset(block_ids)
        assert len(self._all_block_indices) == num_blocks

        self._refcounter = RefCounter(
            all_block_indices=self._free_block_indices)
        self._create_block = create_block
        self._block_size = block_size

        # The tracker only reads refcounts; all refcount changes go through
        # this allocator.
        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates a new immutable block with the given token IDs, linked to
        the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.
            token_ids (List[int]): The token IDs to be stored in the new block.

        Returns:
            Block: The newly allocated immutable block.
        """
        block = self.allocate_mutable(prev_block=prev_block)
        block.append_token_ids(token_ids)
        return block

    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a new mutable block, linked to the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.

        Returns:
            Block: The newly allocated mutable block.

        Raises:
            BlockAllocator.NoFreeBlocksError: If no block index is free.
        """
        block_id = self._allocate_new_block_id()
        return self._create_block(
            prev_block=prev_block,
            token_ids=[],
            block_id=block_id,
            block_size=self._block_size,
            allocator=self,
        )

    def free(self, block: Block) -> None:
        """Drops the block's reference to its block index.

        The index returns to the free set once its reference count reaches
        zero. The block's `block_id` is set to None afterwards.

        Args:
            block (Block): The block to be freed.
        """
        self._free_block_id(block.block_id)

        # Mark the block as having no allocation.
        block.block_id = None

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Every block of the original sequence gets its reference count
        incremented, and a new block object with the same block id and token
        ids is created for it.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:

            # Increment refcount for each block.
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Returns the number of block indices that are currently free."""
        return len(self._free_block_indices)

    def _allocate_new_block_id(self) -> BlockId:
        """Takes one index out of the free set and gives it a reference.

        Raises:
            BlockAllocator.NoFreeBlocksError: If the free set is empty.
        """
        if not self._free_block_indices:
            raise BlockAllocator.NoFreeBlocksError()

        block_id = next(iter(self._free_block_indices))
        self._refcounter.incr(block_id)
        self._free_block_indices.remove(block_id)
        return block_id

    def _free_block_id(self, block_id: BlockId) -> None:
        """Drops one reference and frees the index when none remain."""
        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.add(block_id)

    @property
    def refcounter(self):
        """The RefCounter that tracks this allocator's block indices."""
        return self._refcounter

    @property
    def all_block_ids(self):
        """The frozenset of all block indices managed by this allocator."""
        return self._all_block_indices

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do
        nothing.
        """
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Determine blocks that can be skipped in prefill.

        Since the naive allocator does not support prefix caching, always return
        an empty list.
        """
        return []
class NaiveBlock(Block):
    """An implementation of the Block class that does not support prefix
    caching.

    The NaiveBlock class represents a block of token IDs with a fixed size. It
    provides methods for appending token IDs to the block and manages copy-on
    -write operations when necessary.

    Args:
        prev_block (Block): The previous block in the sequence.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        allocator (BlockAllocator): The block allocator associated with this
            block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None, which means no allocation has been
            made.
        _cow_target (Optional[Block], optional): The copy-on-write target block.
            If not provided, it defaults to self.
    """

    def __init__(self,
                 prev_block: Block,
                 token_ids: List[int],
                 block_size: int,
                 allocator: BlockAllocator,
                 block_id: Optional[int] = None,
                 _cow_target: Optional[Block] = None):
        self._token_ids: List[int] = []
        self._block_size = block_size
        self._prev_block = prev_block
        self._block_id = block_id
        self._allocator = allocator
        # A wrapping block may pass itself here so that the allocator's
        # copy-on-write check sees the wrapper instead of this inner block.
        self._cow_target = _cow_target if _cow_target is not None else self

        # Initial tokens never trigger copy-on-write.
        self._append_token_ids_no_cow(token_ids)

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends the given token IDs to the block, instructing the allocator
        to perform a copy-on-write if necessary.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        self._append_token_ids_no_cow(token_ids)

        # Only allocated blocks can be shared, so only they need the check.
        if self._block_id is not None:
            self._block_id = (self._allocator.cow_block_if_not_appendable(
                self._cow_target))

    def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
        """Extend the stored token IDs without consulting the allocator.

        Args:
            token_ids (List[int]): The token IDs to append; must fit in the
                remaining empty slots.
        """
        assert self.num_empty_slots >= len(token_ids)
        self._token_ids.extend(token_ids)

    @property
    def block_id(self) -> Optional[int]:
        """The physical block index, or None if nothing is allocated."""
        return self._block_id

    @block_id.setter
    def block_id(self, value: Optional[int]) -> None:
        self._block_id = value

    @property
    def is_full(self) -> bool:
        """True when no empty slot is left."""
        return self.num_empty_slots == 0

    @property
    def num_empty_slots(self) -> int:
        """Number of token IDs that can still be appended."""
        return self._block_size - len(self._token_ids)

    @property
    def token_ids(self) -> List[int]:
        """The token IDs stored so far (the internal list, not a copy)."""
        return self._token_ids

    # Exposed as a property like every other accessor on this class.
    # PrefixCachingBlock.block_size forwards `self._block.block_size` and is
    # annotated to return an int, which only holds if this is not a method.
    @property
    def block_size(self) -> int:
        """The maximum number of token IDs the block can hold."""
        return self._block_size

    @property
    def prev_block(self) -> Optional["Block"]:
        """The previous block in the sequence."""
        return self._prev_block
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
# Hash of a full block's token ids combined with the hash of the blocks before
# it; used as the key of the allocator's hash -> block-index caches.
PrefixHash = int
# Index of a physical block.
BlockId = int
class PrefixCachingBlockAllocator(BlockAllocator):
    """A block allocator that implements prefix caching.

    The PrefixCachingBlockAllocator maintains a cache of blocks based on their
    content hash. It reuses blocks with the same content hash to avoid redundant
    memory allocation. The allocator also supports copy-on-write operations.

    Args:
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids(Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    # TODO last access time / evictor integration

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash will be in this dict, even if they have refcount 0.
        self._cached_blocks: Dict[PrefixHash, BlockId] = {}

        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
        # of self._cached_blocks.
        self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}

        # An allocator for blocks that do not have prefix hashes.
        self._hashless_allocator = NaiveBlockAllocator(
            create_block=self._create_block,
            num_blocks=num_blocks,
            block_size=block_size,
            block_ids=block_ids,
        )

        self._block_size = block_size

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
        # blocks.
        self._refcounter = self._hashless_allocator.refcounter

        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    # Implements Block.Factory.
    def _create_block(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
    ) -> Block:
        """Build a PrefixCachingBlock bound to this allocator.

        The ``allocator`` argument is ignored: the created block is always
        bound to ``self``.
        """
        # Bind block to self.
        allocator = self

        return PrefixCachingBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            prefix_caching_allocator=allocator,
        )

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates an immutable block with the given token IDs, reusing cached
        blocks if possible.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
            token_ids (List[int]): The token IDs to be stored in the block.

        Returns:
            Block: The allocated immutable block.
        """
        assert_prefix_caching_block_or_none(prev_block)

        # Unallocated block used to compute the content hash for the lookup.
        block = self._create_block(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=self._block_size,
            allocator=self,
        )
        assert block.content_hash is not None

        cached_block_id = self._cached_blocks.get(block.content_hash, None)
        if cached_block_id is not None:
            block.block_id = cached_block_id
            self._incr_refcount_cached_block(block.content_hash,
                                             block.block_id)
            return block

        # Cache miss: take a mutable block and fill it. Appending the tokens
        # goes through PrefixCachingBlock.append_token_ids.
        block = self.allocate_mutable(prev_block)
        block.append_token_ids(token_ids)
        assert block.content_hash is not None
        # TODO computed bit

        return block

    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a mutable block. If there are no free blocks, this will
        evict unused cached blocks.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence,
                or None.

        Returns:
            Block: The allocated mutable block.

        Raises:
            BlockAllocator.NoFreeBlocksError: If neither the hashless
                allocator nor the unused cached blocks can supply a block.
        """
        assert_prefix_caching_block_or_none(prev_block)

        try:
            return self._hashless_allocator.allocate_mutable(
                prev_block=prev_block)
        except BlockAllocator.NoFreeBlocksError:
            # We must check the unused cached blocks before raising OOM.
            pass

        if self._unused_cached_blocks:
            # TODO policy for selecting block to remove
            content_hash_to_evict = next(iter(self._unused_cached_blocks))

            # Clear content hash mapping; the block will be overwritten.
            del self._cached_blocks[content_hash_to_evict]

            block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
            refcount = self._refcounter.incr(block_id)
            assert refcount == 1
            block = self._create_block(
                prev_block=prev_block,
                token_ids=[],
                block_size=self._block_size,
                allocator=self,
                block_id=block_id,
            )
            assert block.content_hash is None
            return block

        # No block available in hashless allocator, nor in unused cache blocks.
        raise BlockAllocator.NoFreeBlocksError()

    def _incr_refcount_cached_block(self, content_hash: int,
                                    block_id: BlockId) -> None:
        """Add a reference to a cached block; a block going from 0 to 1
        references leaves the unused set."""
        refcount = self._refcounter.incr(block_id)
        if refcount == 1:
            assert content_hash in self._unused_cached_blocks
            del self._unused_cached_blocks[content_hash]

    def free(self, block: Block) -> None:
        """Decrement the refcount of the block. If the decremented refcount is
        zero, store the block in the freelist.

        If the block has a content hash (meaning it is immutable), then we will
        keep the block around in case future allocations require it.
        """
        assert (block.block_id
                is not None), "freeing unallocated block is undefined"

        self._free_block_id_for_block(block.block_id, block)
        block.block_id = None

    def _free_block_id_for_block(self, block_id: BlockId,
                                 block: Block) -> None:
        """Release one reference to ``block_id``, routing on whether
        ``block`` has a content hash."""
        assert isinstance(block, PrefixCachingBlock)

        # Hashless blocks are handed back to the hashless allocator.
        if block.content_hash is None:
            return self._hashless_allocator.free(block)

        refcount = self._refcounter.decr(block_id)

        # If no longer used, add the block to the unused cached blocks.
        if refcount == 0:
            assert block.content_hash not in self._unused_cached_blocks
            assert block.content_hash in self._cached_blocks
            self._unused_cached_blocks[block.content_hash] = block_id

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks: List[Block] = []
        prev_block = None
        for block in source_blocks:
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Return the number of blocks available for allocation."""
        # The number of free blocks is the number of hashless free blocks
        # plus the number of hashful blocks that are unused.
        return self._hashless_allocator.get_num_free_blocks() + len(
            self._unused_cached_blocks)

    # The annotation is quoted so it is never evaluated: subscripting the
    # builtin `frozenset` raises TypeError at class creation before Python 3.9.
    @property
    def all_block_ids(self) -> "frozenset[int]":
        """All block ids known to the underlying hashless allocator."""
        return self._hashless_allocator.all_block_ids

    def promote_to_immutable_block(self,
                                   block: "PrefixCachingBlock") -> BlockId:
        """Once a mutable block is full, it can be promoted to an immutable
        block. This means that its content can be referenced by future blocks
        having the same prefix.

        Note that if we already have a cached block with the same content, we
        will replace the newly-promoted block's mapping with the existing cached
        block.

        Args:
            block (PrefixCachingBlock): The mutable block to be promoted.

        Returns:
            BlockId: Either the original block index, or the block index of
                the previously cached block matching the same content.
        """
        assert block.content_hash is not None
        assert block.block_id is not None
        assert self._refcounter.get(block.block_id) > 0

        # If the content hash does not have a corresponding cached block,
        # set this block as the cached block.
        if block.content_hash not in self._cached_blocks:
            self._cached_blocks[block.content_hash] = block.block_id
        else:
            # Duplicate content: release this block's reference and take one
            # on the already-cached block instead.
            self._free_block_id_for_block(block.block_id, block)
            self._incr_refcount_cached_block(
                block.content_hash, self._cached_blocks[block.content_hash])

        return self._cached_blocks[block.content_hash]

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed, used in prefix caching."""
        # TODO Track computed blocks.
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).

        Args:
            seq_block_ids (List[List[int]]): One list of block ids per
                sequence.

        Returns:
            List[int]: The longest common leading run of computed block ids
                across the sequences that have any; an empty list otherwise.
        """

        # TODO: Track computed blocks.
        def computed(block_id: int) -> bool:
            return False

        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        # The takewhile iterators are materialized into lists: commonprefix
        # needs comparable sequences, and a lazy iterator never equals [].
        ids_list = [
            list(takewhile(computed, seq[:-1])) for seq in seq_block_ids
        ]
        non_empty = [ids for ids in ids_list if ids]
        if not non_empty:
            # commonprefix would return '' here; keep the declared list type.
            return []
        return commonprefix(non_empty)
class PrefixCachingBlock(Block):
    """Block of token IDs that can take part in prefix caching.

    All token storage is delegated to an internal NaiveBlock. On top of that
    this class derives a content hash once the block is full and registers
    the block with the prefix caching allocator when that hash exists.

    Args:
        prev_block (Optional[PrefixCachingBlock]): The previous block in the
            sequence.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
            caching block allocator associated with this block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None.
    """

    def __init__(
        self,
        prev_block: Optional["PrefixCachingBlock"],
        token_ids: List[int],
        block_size: int,
        prefix_caching_allocator: PrefixCachingBlockAllocator,
        block_id: Optional[int] = None,
    ):
        assert_prefix_caching_block_or_none(prev_block)

        self._prev_block = prev_block
        self._cached_content_hash: Optional[int] = None
        self._prefix_caching_allocator = prefix_caching_allocator

        # The inner block reports this wrapper as its copy-on-write target.
        inner = NaiveBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            allocator=prefix_caching_allocator,
            _cow_target=self,
        )
        self._block = inner

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Append token IDs and, if a content hash becomes available, register
        the block with the allocator as immutable.

        Copy-on-write is handled by the inner naive block.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        assert token_ids

        self._block.append_token_ids(token_ids)

        if self.content_hash is None:
            # No hash yet, so the block stays mutable.
            return

        # Promotion may hand back a different physical block index.
        promoted_id = (
            self._prefix_caching_allocator.promote_to_immutable_block(self))
        self.block_id = promoted_id

    @property
    def block_id(self) -> Optional[int]:
        """Physical block index of the inner block."""
        return self._block.block_id

    @block_id.setter
    def block_id(self, value) -> None:
        self._block.block_id = value

    @property
    def is_full(self) -> bool:
        """Whether the inner block reports itself as full."""
        return self._block.is_full

    @property
    def num_empty_slots(self) -> int:
        """Empty slots reported by the inner block."""
        return self._block.num_empty_slots

    @property
    def block_size(self) -> int:
        """The inner block's ``block_size`` attribute."""
        return self._block.block_size

    @property
    def token_ids(self) -> List[int]:
        """Token IDs held by the inner block."""
        return self._block.token_ids

    @property
    def prev_block(self) -> Optional[Block]:
        """The previous block in the sequence, if any."""
        return self._prev_block

    @property
    def content_hash(self) -> Optional[int]:
        """Content-based hash of this block, or None when not yet defined.

        The hash is defined only when this block is full and, unless it is
        the first block, the previous block has a hash too. Once computed it
        is memoized.
        """
        memo = self._cached_content_hash
        if memo is not None:
            return memo

        # A block that is not full has no hash yet.
        if not self.is_full:
            return None

        predecessor = self._prev_block
        if predecessor is None:
            first_in_chain = True
            parent_hash = None
        else:
            first_in_chain = False
            parent_hash = predecessor.content_hash
            # The previous block exists but has no hash yet.
            if parent_hash is None:
                return None

        digest = PrefixCachingBlock.hash_block_tokens(
            first_in_chain,
            parent_hash,
            cur_block_token_ids=self.token_ids)
        self._cached_content_hash = digest
        return digest

    @staticmethod
    def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                          cur_block_token_ids: List[int]) -> int:
        """Hash a block's token IDs together with the previous block's hash.

        The value is used for prefix caching.

        NOTE: Content-based hashing does not yet support LoRA.

        Parameters:
        - is_first_block (bool): A flag indicating if the block is the first in
            the sequence.
        - prev_block_hash (Optional[int]): The hash of the previous block. None
            if this is the first block.
        - cur_block_token_ids (List[int]): A list of token ids in the current
            block. The current block is assumed to be full.

        Returns:
        - int: The computed hash value for the block.
        """
        assert (prev_block_hash is None) == is_first_block
        hash_key = (is_first_block, prev_block_hash) + tuple(
            cur_block_token_ids)
        return hash(hash_key)
def assert_prefix_caching_block_or_none(block: Optional[Block]):
    """Assert that ``block`` is a PrefixCachingBlock unless it is None."""
    if block is not None:
        assert isinstance(block, PrefixCachingBlock)
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
import enum from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device from vllm.utils import Device
logger = init_logger(__name__)
class BlockAllocator:
class BlockAllocatorBase(ABC):
    """Abstract interface for allocators of physical token blocks on a device.

    Subclasses provide allocation, freeing, free-block counting, and hash
    lookup/update; every method here is abstract.
    """

    @abstractmethod
    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
        """Set up the allocator for ``num_blocks`` blocks on ``device``."""

    @abstractmethod
    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        """Return a physical block, optionally tagged with ``block_hash``."""

    @abstractmethod
    def free(self, block: PhysicalTokenBlock) -> None:
        """Release one reference to ``block``."""

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Return how many blocks can still be allocated."""

    @abstractmethod
    def contains_block(self, block_hash: int) -> bool:
        """Return whether a block with ``block_hash`` is known."""

    @abstractmethod
    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        """Re-key ``block`` under ``block_hash``."""
class CachedBlockAllocator(BlockAllocatorBase):
    """Hash-keyed allocator of physical token blocks for a device.

    Blocks in use are tracked in ``cached_blocks`` by their hash. A block
    whose reference count drops to zero is handed to the evictor, from which
    it can be revived by hash or recycled once all blocks have been created.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Number of PhysicalTokenBlock objects created so far.
        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        self.evictor: Evictor = make_evictor(eviction_policy)

        # Source of placeholder hashes for blocks allocated without one.
        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        """Create a new block, or recycle an evicted one once every block
        number has been handed out."""
        if self.current_num_blocks != self.num_blocks:
            fresh = PhysicalTokenBlock(device=self.device,
                                       block_number=self.current_num_blocks,
                                       block_size=self.block_size,
                                       block_hash=block_hash,
                                       num_hashed_tokens=num_hashed_tokens)
            self.current_num_blocks += 1
            return fresh

        # All blocks exist already: reuse whatever the evictor gives up.
        recycled = self.evictor.evict()
        recycled.block_hash = block_hash
        recycled.num_hashed_tokens = num_hashed_tokens
        return recycled

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        """Return the block for ``block_hash`` with one more reference.

        When ``block_hash`` is None, the next value of the default counter is
        used as the hash.
        """
        key = next(self.default_hash_ctr) if block_hash is None else block_hash

        if key in self.evictor:
            # Revive a block that is waiting in the evictor.
            assert key not in self.cached_blocks
            revived = self.evictor.remove(key)
            assert revived.ref_count == 0
            self.cached_blocks[key] = revived
            revived.ref_count += 1
            assert revived.block_hash == key
            return revived

        if key not in self.cached_blocks:
            self.cached_blocks[key] = self.allocate_block(
                key, num_hashed_tokens)
        shared = self.cached_blocks[key]
        assert shared.block_hash == key
        shared.ref_count += 1
        return shared

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference; at zero the block moves to the evictor.

        Raises:
            ValueError: If the block's reference count is already zero.
        """
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count != 0:
            return

        assert block.block_hash not in self.evictor
        self.evictor.add(block)

        # Remove the block from the cached_blocks
        del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        """Blocks not yet created plus blocks held by the evictor."""
        not_yet_created = self.num_blocks - self.current_num_blocks
        return not_yet_created + self.evictor.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        """Whether ``block_hash`` is in use or waiting in the evictor."""
        if block_hash in self.cached_blocks:
            return True
        return block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        """Re-key ``block`` in ``cached_blocks`` under ``block_hash``."""
        assert not self.contains_block(block_hash)
        previous_hash = block.block_hash
        block.block_hash = block_hash
        del self.cached_blocks[previous_hash]
        self.cached_blocks[block_hash] = block
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device. """Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when The allocator maintains a list of free blocks and allocates a block when
...@@ -30,10 +163,14 @@ class BlockAllocator: ...@@ -30,10 +163,14 @@ class BlockAllocator:
for i in range(num_blocks): for i in range(num_blocks):
block = PhysicalTokenBlock(device=device, block = PhysicalTokenBlock(device=device,
block_number=i, block_number=i,
block_size=block_size) block_size=block_size,
block_hash=-1,
num_hashed_tokens=0)
self.free_blocks.append(block) self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock: def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if not self.free_blocks: if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.") raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop() block = self.free_blocks.pop()
...@@ -50,22 +187,16 @@ class BlockAllocator: ...@@ -50,22 +187,16 @@ class BlockAllocator:
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self) -> int:
return len(self.free_blocks) return len(self.free_blocks)
def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
class AllocStatus(enum.Enum): def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
"""Result for BlockSpaceManager.can_allocate raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated.
The capacity of allocator is larger than seq_group required.
3. Never: seq_group can never be allocated.
The seq_group is too large to allocated in GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager: class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks.""" """Manages the mapping between logical and physical token blocks."""
def __init__( def __init__(
...@@ -75,11 +206,16 @@ class BlockSpaceManager: ...@@ -75,11 +206,16 @@ class BlockSpaceManager:
num_cpu_blocks: int, num_cpu_blocks: int,
watermark: float = 0.01, watermark: float = 0.01,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks self.num_total_cpu_blocks = num_cpu_blocks
if enable_caching and sliding_window is not None:
raise NotImplementedError(
"Sliding window is not allowed with prefix caching enabled!")
self.block_sliding_window = None self.block_sliding_window = None
if sliding_window is not None: if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window, assert sliding_window % block_size == 0, (sliding_window,
...@@ -89,11 +225,21 @@ class BlockSpaceManager: ...@@ -89,11 +225,21 @@ class BlockSpaceManager:
self.watermark = watermark self.watermark = watermark
assert watermark >= 0.0 assert watermark >= 0.0
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks) self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
num_gpu_blocks) if self.enable_caching:
self.cpu_allocator = BlockAllocator(Device.CPU, block_size, logger.info("Automatic prefix caching is enabled.")
num_cpu_blocks) self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = UncachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable. # Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {} self.block_tables: Dict[int, BlockTable] = {}
...@@ -103,9 +249,6 @@ class BlockSpaceManager: ...@@ -103,9 +249,6 @@ class BlockSpaceManager:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
if seq_group.prefix is not None and seq_group.prefix.allocated:
num_required_blocks -= seq_group.prefix.get_num_blocks()
if self.block_sliding_window is not None: if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks, num_required_blocks = min(num_required_blocks,
self.block_sliding_window) self.block_sliding_window)
...@@ -129,36 +272,22 @@ class BlockSpaceManager: ...@@ -129,36 +272,22 @@ class BlockSpaceManager:
num_prompt_blocks = len(seq.logical_token_blocks) num_prompt_blocks = len(seq.logical_token_blocks)
block_table: BlockTable = [] block_table: BlockTable = []
prefix_block_table: BlockTable = []
num_prefix_blocks = 0
prefix = seq_group.prefix
if prefix is not None and prefix.allocated:
# Prefix has already been allocated. Use the existing block table.
num_prompt_blocks -= prefix.get_num_blocks()
for block in prefix.block_table:
block.ref_count += seq_group.num_seqs()
block_table.append(block)
for logical_idx in range(num_prompt_blocks): for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window): and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window] block = block_table[logical_idx % self.block_sliding_window]
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
elif self.enable_caching:
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
else: else:
block = self.gpu_allocator.allocate() block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks. # Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs() block.ref_count = seq_group.num_seqs()
block_table.append(block) block_table.append(block)
if prefix is not None and not prefix.allocated:
# Allocate blocks for the prefix, we will compute the prefix's
# KV cache in this run.
num_prefix_blocks = prefix.get_num_blocks()
prefix_block_table = block_table[:num_prefix_blocks]
for block in prefix_block_table:
block.ref_count += 1
prefix.set_block_table(prefix_block_table)
# Assign the block table for each sequence. # Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy() self.block_tables[seq.seq_id] = block_table.copy()
...@@ -170,12 +299,83 @@ class BlockSpaceManager: ...@@ -170,12 +299,83 @@ class BlockSpaceManager:
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks return num_seqs <= num_free_gpu_blocks
def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: def _promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
assert self.enable_caching
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# if new_hash is already in the cached table, then free last_block
# and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
else:
self.gpu_allocator.update_hash(new_hash, last_block)
return last_block
def _is_last_block_full(
    self,
    seq: Sequence,
) -> bool:
    """Return True when ``seq`` has tokens and their count is an exact
    multiple of ``seq.block_size``."""
    num_tokens = len(seq.data.get_token_ids())
    if num_tokens <= 0:
        return False
    return num_tokens % seq.block_size == 0
def _maybe_promote_last_block(
    self,
    seq: Sequence,
    last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
    """Promote ``last_block`` if the sequence's last block is full;
    otherwise return it untouched."""
    if not self._is_last_block_full(seq):
        return last_block
    return self._promote_last_block(seq, last_block)
def _allocate_last_physical_block(
    self,
    seq: Sequence,
) -> PhysicalTokenBlock:
    """Allocate the physical block for the sequence's last logical block.

    Called before a new block is appended. Without caching a plain block is
    allocated. With caching the block is requested with the content hash of
    the last logical block when that block is full, and with no hash
    otherwise.
    """
    if not self.enable_caching:
        return self.gpu_allocator.allocate()

    last_block_is_full = self._is_last_block_full(seq)
    block_hash: Optional[int] = None
    if last_block_is_full:
        block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

    # num_hashed_tokens is used to compute future hashes
    # (e.g. in the hashing function, it is used to ask the sequence for
    # prefix tokens)
    last_index = len(seq.logical_token_blocks) - 1
    num_hashed_tokens = seq.num_hashed_tokens_of_block(last_index)

    new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)

    # A missing hash means the block is not full, and a block that is not
    # full is expected to have a refcount of 1.
    if block_hash is None:
        assert new_block.ref_count == 1
    return new_block
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token.""" """Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks): if len(block_table) < len(logical_blocks):
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
if (self.block_sliding_window if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window): and len(block_table) >= self.block_sliding_window):
# reuse a block # reuse a block
...@@ -184,8 +384,8 @@ class BlockSpaceManager: ...@@ -184,8 +384,8 @@ class BlockSpaceManager:
else: else:
# The sequence has a new logical block. # The sequence has a new logical block.
# Allocate a new physical block. # Allocate a new physical block.
block = self.gpu_allocator.allocate() new_block = self._allocate_last_physical_block(seq)
block_table.append(block) block_table.append(new_block)
return None return None
# We want to append the token to the last physical block. # We want to append the token to the last physical block.
...@@ -193,11 +393,18 @@ class BlockSpaceManager: ...@@ -193,11 +393,18 @@ class BlockSpaceManager:
assert last_block.device == Device.GPU assert last_block.device == Device.GPU
if last_block.ref_count == 1: if last_block.ref_count == 1:
# Not shared with other sequences. Appendable. # Not shared with other sequences. Appendable.
if self.enable_caching:
# If the last block is now complete, we may reuse an old block
# to save memory.
maybe_new_block = self._maybe_promote_last_block(
seq, last_block)
block_table[-1] = maybe_new_block
return None return None
else: else:
# The last block is shared with other sequences. # The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens. # Copy on Write: Allocate a new block and copy the tokens.
new_block = self.gpu_allocator.allocate() new_block = self._allocate_last_physical_block(seq)
block_table[-1] = new_block block_table[-1] = new_block
self.gpu_allocator.free(last_block) self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number return last_block.block_number, new_block.block_number
...@@ -207,7 +414,12 @@ class BlockSpaceManager: ...@@ -207,7 +414,12 @@ class BlockSpaceManager:
# Thus, it is always safe from OOM. # Thus, it is always safe from OOM.
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy() self.block_tables[child_seq.seq_id] = src_block_table.copy()
for block in src_block_table: # When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
# is only incremented by one, so we deduplicate them by wrapping
# them in a set.
for block in set(src_block_table):
block.ref_count += 1 block.ref_count += 1
def _get_physical_blocks( def _get_physical_blocks(
...@@ -233,25 +445,18 @@ class BlockSpaceManager: ...@@ -233,25 +445,18 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block. # CPU block -> GPU block.
if seq_group.prefix is not None:
# make sure to swap in the prefix first
assert seq_group.prefix.allocated and seq_group.prefix.computed
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
if seq_group.prefix is not None:
for block in seq_group.prefix.block_table:
new_block_table.append(block)
block.ref_count += 1
for cpu_block in block_table: for cpu_block in block_table:
if cpu_block in mapping: if cpu_block in mapping:
gpu_block = mapping[cpu_block] gpu_block = mapping[cpu_block]
gpu_block.ref_count += 1 gpu_block.ref_count += 1
else: else:
gpu_block = self.gpu_allocator.allocate() gpu_block = self.gpu_allocator.allocate(
cpu_block.block_hash, cpu_block.num_hashed_tokens)
mapping[cpu_block] = gpu_block mapping[cpu_block] = gpu_block
new_block_table.append(gpu_block) new_block_table.append(gpu_block)
# Free the CPU block swapped in to GPU. # Free the CPU block swapped in to GPU.
...@@ -276,17 +481,12 @@ class BlockSpaceManager: ...@@ -276,17 +481,12 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table: for gpu_block in block_table:
if (seq_group.prefix is not None
and gpu_block in seq_group.prefix.block_table):
# NOTE: We do not swap out the prefix blocks for now.
self.gpu_allocator.free(gpu_block)
continue
if gpu_block in mapping: if gpu_block in mapping:
cpu_block = mapping[gpu_block] cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1 cpu_block.ref_count += 1
else: else:
cpu_block = self.cpu_allocator.allocate() cpu_block = self.cpu_allocator.allocate(
gpu_block.block_hash, gpu_block.num_hashed_tokens)
mapping[gpu_block] = cpu_block mapping[gpu_block] = cpu_block
new_block_table.append(cpu_block) new_block_table.append(cpu_block)
# Free the GPU block swapped out to CPU. # Free the GPU block swapped out to CPU.
...@@ -300,7 +500,15 @@ class BlockSpaceManager: ...@@ -300,7 +500,15 @@ class BlockSpaceManager:
return block_number_mapping return block_number_mapping
def _free_block_table(self, block_table: BlockTable) -> None: def _free_block_table(self, block_table: BlockTable) -> None:
for block in set(block_table): # when using a sliding window, each seq will only use up
# to `self.block_sliding_window` blocks. When freeing
# the block table, we must make sure to not free blocks more
# than once. If no sliding window is used, there is no block
# reuse in the block table, so we must free all blocks.
blocks_to_free = (block_table[-self.block_sliding_window:]
if self.block_sliding_window is not None else
block_table)
for block in set(blocks_to_free):
if block.device == Device.GPU: if block.device == Device.GPU:
self.gpu_allocator.free(block) self.gpu_allocator.free(block)
else: else:
...@@ -328,3 +536,56 @@ class BlockSpaceManager: ...@@ -328,3 +536,56 @@ class BlockSpaceManager:
def get_num_free_cpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks() return self.cpu_allocator.get_num_free_blocks()
def access_all_blocks_in_seq(
    self,
    seq: Sequence,
    access_time: float,
) -> None:
    """Stamp every block of `seq` with `access_time`.

    Does nothing unless `self.enable_caching` is set; in that case the
    sequence's block table is not even looked up.

    Args:
        seq: The sequence whose block table (keyed by `seq.seq_id`) is
            updated.
        access_time: Value written to each block's `last_accessed`.
    """
    if not self.enable_caching:
        return
    # Update the last accessed time of all the blocks accessed
    # in this step.
    for touched_block in self.block_tables[seq.seq_id]:
        touched_block.last_accessed = access_time
def compute_full_blocks_in_seq(self, seq: Sequence):
    """Flag the leading blocks of `seq` as computed.

    `seq.get_len() // self.block_size - 1` is taken as an exclusive upper
    bound: blocks at indices below it are walked from the highest index
    down and marked `computed = True`. The walk stops at the first block
    that is already marked, so earlier blocks are left untouched.

    Sequences without a block table, or shorter than one block, are
    ignored.
    """
    if seq.seq_id not in self.block_tables:
        return
    upper_bound = seq.get_len() // self.block_size - 1
    if upper_bound == -1:
        # Not even one full block yet.
        return
    table = self.block_tables[seq.seq_id]
    for idx in range(upper_bound - 1, -1, -1):
        current = table[idx]
        if current.computed:
            break
        current.computed = True
def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
    """Return block numbers of the leading run of computed blocks of `seq`.

    Returns an empty list when `seq` has no block table. The scan stops at
    the first block whose `computed` flag is false.
    """
    if seq.seq_id not in self.block_tables:
        return []
    table = self.block_tables[seq.seq_id]
    # NOTE We exclude the last block to avoid the case where the entire
    # prompt is cached. This would cause erroneous behavior in model
    # runner.
    computed_numbers: List[int] = []
    for candidate in table[:-1]:
        if not candidate.computed:
            break
        computed_numbers.append(candidate.block_number)
    return computed_numbers
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
    """Return the block ids that are common for a given sequence group.

    Used in prefill (can skip prefill of some blocks).

    Args:
        seqs: Sequences whose computed block ids are compared.

    Returns:
        The longest common leading run of computed block ids across the
        sequences that have at least one computed block. Sequences with no
        computed blocks are left out of the comparison. Always a list:
        empty when prefix caching is disabled or when no sequence has any
        computed block.
    """
    # Can return non-empty result only with prefix caching enabled.
    if not self.enable_caching:
        return []

    ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
    non_empty_ids = [ids for ids in ids_list if ids != []]
    if not non_empty_ids:
        # os.path.commonprefix returns '' (a str) for an empty input, which
        # would violate the List[int] return type; return a real list.
        return []
    return commonprefix(non_empty_ids)
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    """Run `compute_full_blocks_in_seq` on every sequence of `seq_group`.

    Does nothing unless `self.enable_caching` is set.
    """
    if not self.enable_caching:
        return
    for member in seq_group.seqs_dict.values():
        self.compute_full_blocks_in_seq(member)
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
SeqId = int
class BlockSpaceManagerV2(BlockSpaceManager):
    """BlockSpaceManager which manages the allocation of KV cache.

    It owns responsibility for allocation, swapping, allocating memory for
    autoregressively-generated tokens, and other advanced features such as
    prefix caching, forking/copy-on-write, and sliding-window memory
    allocation.

    The current implementation is partial; in particular prefix caching and
    sliding-window are not feature complete (the constructor asserts that
    neither is requested) and swapping is reported as unavailable. This
    class implements the design described in
    https://github.com/vllm-project/vllm/pull/3492.

    Args:
        block_size (int): The size of each memory block.
        num_gpu_blocks (int): The number of memory blocks allocated on GPU.
        num_cpu_blocks (int): The number of memory blocks allocated on CPU.
        watermark (float, optional): The threshold used for memory swapping.
            Defaults to 0.01.
        sliding_window (Optional[int], optional): The size of the sliding
            window. Defaults to None.
        enable_caching (bool, optional): Flag indicating whether caching is
            enabled. Defaults to False.
    """

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        # Sliding window is rejected up front, so the derived block window
        # is always None here.
        assert sliding_window is None, "Sliding window not yet supported"
        self.block_sliding_window = None

        self.watermark = watermark
        assert watermark >= 0.0

        assert not enable_caching, "Prefix caching not yet supported"
        self.enable_caching = enable_caching

        # Number of GPU blocks kept in reserve by can_allocate.
        self.watermark_blocks = int(watermark * num_gpu_blocks)

        # Currently, only naive blocks are supported (no prefix caching).
        self.block_allocator = CpuGpuBlockAllocator.create(
            allocator_type="naive",
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            block_size=block_size,
        )

        self.block_tables: Dict[SeqId, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Report whether the group's prompt fits now, later, or never."""
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        first_waiting = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        blocks_needed = BlockTable.get_num_required_blocks(
            first_waiting.get_token_ids(),
            block_size=self.block_size,
        )

        assert self.block_sliding_window is None
        if self.block_sliding_window is not None:
            blocks_needed = min(blocks_needed, self.block_sliding_window)

        free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            device=Device.GPU)

        # Use watermark to avoid frequent cache eviction.
        if self.num_total_gpu_blocks - blocks_needed < self.watermark_blocks:
            return AllocStatus.NEVER
        if free_gpu_blocks - blocks_needed >= self.watermark_blocks:
            return AllocStatus.OK
        return AllocStatus.LATER

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Create block tables for all waiting sequences of the group."""
        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        assert not ({s.seq_id for s in waiting_seqs}
                    & self.block_tables.keys()), "block table already exists"

        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        head_seq = waiting_seqs[0]

        head_table = BlockTable(
            block_size=self.block_size,
            block_allocator=self.block_allocator,
        )
        assert self.block_sliding_window is None
        head_table.allocate(head_seq.get_token_ids())
        self.block_tables[head_seq.seq_id] = head_table

        # Assign the block table for each sequence.
        for sibling in waiting_seqs[1:]:
            self.block_tables[sibling.seq_id] = head_table.fork()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Return True when each running sequence can get one free block."""
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            Device.GPU)
        running_count = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return running_count <= free_gpu_blocks

    def append_slot(
        self,
        seq: Sequence,
    ) -> Optional[Tuple[int, int]]:
        """Append the not-yet-stored tokens of `seq` to its block table.

        Always returns None; the allocator's copy-on-write records are
        cleared and discarded.
        """
        table = self.block_tables[seq.seq_id]

        # Get unseen token ids.
        stored_slots = table.num_full_slots
        pending_token_ids = seq.get_token_ids()[stored_slots:]
        assert pending_token_ids

        table.append_token_ids(pending_token_ids)

        # Return any copy-on-writes.
        self.block_allocator.clear_copy_on_writes()

        # TODO extend append_slot interface to append_slots
        # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
        return None

    def free(self, seq: Sequence) -> None:
        """Release the block table of `seq`, if it has one."""
        if seq.seq_id in self.block_tables:
            self.block_tables[seq.seq_id].free()
            del self.block_tables[seq.seq_id]
        # Otherwise: already freed or haven't been scheduled yet.

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block ids backing `seq`."""
        assert seq.seq_id in self.block_tables
        physical_ids = self.block_tables[seq.seq_id].physical_block_ids
        assert all(block_id is not None for block_id in physical_ids)
        return physical_ids

    def access_all_blocks_in_seq(self, seq, now):
        """No-op placeholder until prefix caching is supported."""
        # TODO add prefix caching support.
        # Tracked here https://github.com/vllm-project/vllm/issues/3667
        pass

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Ask the allocator to mark blocks as computed."""
        # We ignore the sequence group as its not necessary. After the batch is
        # formed by the scheduler, we do not need to mark blocks from individual
        # sequence groups as computed -- all blocks in the batch can be marked
        # as computed.
        self.block_allocator.mark_blocks_as_computed()

    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
        """Determine which blocks for which we skip prefill.

        With prefix caching we can skip prefill for previously-generated
        blocks. Currently, the attention implementation only supports
        skipping cached blocks if they are a contiguous prefix of cached
        blocks.

        This method determines which blocks can be safely skipped for all
        sequences in the sequence group.
        """
        per_seq_block_ids = [
            self.block_tables[member.seq_id].physical_block_ids
            for member in seqs
        ]
        return self.block_allocator.get_common_computed_block_ids(
            per_seq_block_ids)

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Give `child_seq` a fork of the parent's block table."""
        parent_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = parent_table.fork()

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Swapping in is not implemented; always False."""
        return False

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Not implemented."""
        raise NotImplementedError

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Swapping out is not implemented; always False."""
        return False

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Not implemented."""
        raise NotImplementedError

    def get_num_free_gpu_blocks(self) -> int:
        """Return the allocator's free-block count for the GPU device."""
        return self.block_allocator.get_num_free_blocks(Device.GPU)

    def get_num_free_cpu_blocks(self) -> int:
        """Return the allocator's free-block count for the CPU device."""
        return self.block_allocator.get_num_free_blocks(Device.CPU)
import enum
from abc import ABC, abstractmethod, abstractproperty
from typing import OrderedDict
from vllm.block import PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
    """Enum for eviction policy used by make_evictor to instantiate the correct
    Evictor subclass.
    """
    # Least-recently-used eviction; make_evictor maps this to LRUEvictor.
    LRU = enum.auto()
class Evictor(ABC):
    """The Evictor subclasses should be used by the BlockAllocator class to
    handle eviction of freed PhysicalTokenBlocks.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __contains__(self, block_hash: int) -> bool:
        """Return whether a block with hash `block_hash` is held."""
        pass

    @abstractmethod
    def evict(self) -> PhysicalTokenBlock:
        """Runs the eviction algorithm and returns the evicted block"""
        pass

    @abstractmethod
    def add(self, block: PhysicalTokenBlock):
        """Adds block to the evictor, making it a candidate for eviction"""
        pass

    @abstractmethod
    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Simply removes the block with the hash value block_hash from the
        evictor. Caller is responsible for making sure that block_hash is
        contained in the evictor before calling remove. Should be used to
        "bring back" blocks that have been freed but not evicted yet.
        """
        pass

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `property` on `abstractmethod` is the supported equivalent.
    @property
    @abstractmethod
    def num_blocks(self) -> int:
        """Number of blocks currently held by the evictor."""
        pass
class LRUEvictor(Evictor):
    """Evicts in a least-recently-used order using the last_accessed timestamp
    that's recorded in the PhysicalTokenBlock. If there are multiple blocks
    with the same last_accessed time, then the one with the largest
    num_hashed_tokens will be evicted. If two blocks each have the lowest
    last_accessed time and highest num_hashed_tokens value, then one will be
    chosen arbitrarily.
    """

    def __init__(self):
        # Freed blocks keyed by block hash, kept in insertion order.
        self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    def evict(self) -> PhysicalTokenBlock:
        """Pop and return the chosen block, clearing its `computed` flag.

        Raises:
            ValueError: If no block is available for eviction.
        """
        table = self.free_table
        if not table:
            raise ValueError("No usable cache memory left")

        victim = next(iter(table.values()))

        # The blocks with the lowest timestamps should be placed consecutively
        # at the start of OrderedDict. Loop through all these blocks to
        # find the one with maximum number of hashed tokens.
        for candidate in table.values():
            if victim.last_accessed < candidate.last_accessed:
                break
            if victim.num_hashed_tokens < candidate.num_hashed_tokens:
                victim = candidate

        table.pop(victim.block_hash)
        victim.computed = False
        return victim

    def add(self, block: PhysicalTokenBlock):
        """Register `block` under its `block_hash`."""
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Take the block with `block_hash` out of the table and return it.

        Raises:
            ValueError: If no block with that hash is held.
        """
        table = self.free_table
        if block_hash in table:
            return table.pop(block_hash)
        raise ValueError(
            "Attempting to remove block that's not in the evictor")

    @property
    def num_blocks(self) -> int:
        return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    """Build the Evictor that implements `eviction_policy`.

    Raises:
        ValueError: If the policy is not a known EvictionPolicy member.
    """
    if eviction_policy != EvictionPolicy.LRU:
        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
    return LRUEvictor()
import enum
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from vllm.sequence import Sequence, SequenceGroup
class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate.

    1. OK: seq_group can be allocated now.
    2. LATER: seq_group cannot be allocated right now, but the total
      capacity of the allocator is large enough for it.
    3. NEVER: seq_group can never be allocated because it is too large to
      fit in GPU memory.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()
class BlockSpaceManager(ABC):
    """Abstract interface implemented by the block space managers."""

    @staticmethod
    def get_block_space_manager_class(version: str):
        """Return the manager class for `version` ("v1" or "v2").

        The lookup is case-insensitive and the implementation module is
        imported lazily.

        Raises:
            ValueError: If `version` is neither "v1" nor "v2".
        """
        version = version.lower()

        if version not in ("v1", "v2"):
            raise ValueError(f"Unknown version {version=}")

        if version == "v1":
            from vllm.core.block_manager_v1 import (
                BlockSpaceManagerV1 as manager_cls)
        else:
            from vllm.core.block_manager_v2 import (
                BlockSpaceManagerV2 as manager_cls)
        return manager_cls

    @abstractmethod
    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Report whether `seq_group` can be allocated (see AllocStatus)."""

    @abstractmethod
    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for `seq_group`."""

    @abstractmethod
    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        """Return whether a slot can be appended for `seq_group`."""

    @abstractmethod
    def append_slot(
        self,
        seq: Sequence,
    ) -> Optional[Tuple[int, int]]:
        """Make room for a new token of `seq`.

        Returns a pair of block numbers (source, destination) when a block
        had to be copied, otherwise None.
        """

    @abstractmethod
    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Set up `child_seq` from the blocks of `parent_seq`."""

    @abstractmethod
    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        """Return whether `seq_group` can be swapped in."""

    @abstractmethod
    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Swap `seq_group` in; returns a block-number mapping."""

    @abstractmethod
    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Return whether `seq_group` can be swapped out."""

    @abstractmethod
    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Swap `seq_group` out; returns a block-number mapping."""

    @abstractmethod
    def free(self, seq: Sequence) -> None:
        """Release the blocks held by `seq`."""

    @abstractmethod
    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the block numbers that back `seq`."""

    @abstractmethod
    def get_num_free_gpu_blocks(self) -> int:
        """Return the number of free GPU blocks."""

    @abstractmethod
    def get_num_free_cpu_blocks(self) -> int:
        """Return the number of free CPU blocks."""

    @abstractmethod
    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Record that the blocks of `seq` were accessed at `access_time`."""

    @abstractmethod
    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
        """Return computed block ids shared by `seqs`."""

    @abstractmethod
    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Mark blocks as computed after `seq_group` was scheduled."""
from collections import deque
import enum import enum
import time import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,11 +28,24 @@ class PreemptionMode(enum.Enum): ...@@ -28,11 +28,24 @@ class PreemptionMode(enum.Enum):
RECOMPUTE = enum.auto() RECOMPUTE = enum.auto()
# seq_group: SequenceGroup to schedule.
# token_chunk_size: The number of prefill tokens to be processed in the next
# step.
@dataclass
class ScheduledSequenceGroup:
    """A scheduled sequence group and how many tokens to run for it next."""
    # A sequence group that's scheduled.
    seq_group: SequenceGroup
    # The total chunk size (number of tokens) to process for next iteration.
    # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
    # chunked, it can be smaller than that.
    token_chunk_size: int
class SchedulerOutputs: class SchedulerOutputs:
def __init__( def __init__(
self, self,
scheduled_seq_groups: Iterable[SequenceGroup], scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
prompt_run: bool, prompt_run: bool,
num_batched_tokens: int, num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
...@@ -40,17 +53,41 @@ class SchedulerOutputs: ...@@ -40,17 +53,41 @@ class SchedulerOutputs:
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup],
) -> None: ) -> None:
self.scheduled_seq_groups = scheduled_seq_groups """A list of sequence groups to be scheduled as a single batch.
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens Args:
self.blocks_to_swap_in = blocks_to_swap_in scheduled_seq_groups: A tuple of scheduled sequence group and its
self.blocks_to_swap_out = blocks_to_swap_out token chunk size.
self.blocks_to_copy = blocks_to_copy prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# A tuple of scheduled sequence group and its chunk size.
self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self.prompt_run: bool = prompt_run
# Total number of batched tokens.
self.num_batched_tokens: int = num_batched_tokens
# Blocks to swap in. Dict of CPU -> GPU block number.
self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
# Blocks to swap out. Dict of GPU -> CPU block number.
self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
# Blocks to copy. Source to a list of dest blocks.
self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
# Sequence groups that are going to be ignored.
self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups
# Swap in and swap out should never happen at the same time. # Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out) assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
self.num_loras = len(self.lora_requests) self.num_loras: int = len(self.lora_requests)
if self.num_loras > 0: if self.num_loras > 0:
self._sort_by_lora_ids() self._sort_by_lora_ids()
...@@ -62,12 +99,11 @@ class SchedulerOutputs: ...@@ -62,12 +99,11 @@ class SchedulerOutputs:
def _sort_by_lora_ids(self) -> bool: def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted( self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups, self.scheduled_seq_groups,
key=lambda g: (g.lora_request.lora_int_id key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
if g.lora_request else 0, g.request_id))
@property @property
def lora_requests(self) -> Set[LoRARequest]: def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups} return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
class Scheduler: class Scheduler:
...@@ -90,15 +126,18 @@ class Scheduler: ...@@ -90,15 +126,18 @@ class Scheduler:
# Instantiate the scheduling policy. # Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs") self.policy = PolicyFactory.get_policy(policy_name="fcfs")
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config.
use_v2_block_manager else "v1")
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManager( self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks, num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window) sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
# Create the prefix pool to cache the prefixes.
self.prefix_pool = PrefixPool(self.cache_config.block_size)
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: Deque[SequenceGroup] = deque() self.waiting: Deque[SequenceGroup] = deque()
...@@ -107,6 +146,13 @@ class Scheduler: ...@@ -107,6 +146,13 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
# Time at previous scheduling step
self.prev_time = 0.0
# Did we schedule a prompt at previous step?
self.prev_prompt = False
# Latency of the last prompt step
self.last_prompt_latency = 0.0
@property @property
def lora_enabled(self) -> bool: def lora_enabled(self) -> bool:
return bool(self.lora_config) return bool(self.lora_config)
...@@ -164,7 +210,7 @@ class Scheduler: ...@@ -164,7 +210,7 @@ class Scheduler:
blocks_to_copy: Dict[int, List[int]] = {} blocks_to_copy: Dict[int, List[int]] = {}
# Fix the current time. # Fix the current time.
now = time.monotonic() now = time.time()
# Join waiting sequences if possible. # Join waiting sequences if possible.
if not self.swapped: if not self.swapped:
...@@ -177,24 +223,26 @@ class Scheduler: ...@@ -177,24 +223,26 @@ class Scheduler:
curr_loras = set( curr_loras = set(
seq_group.lora_int_id seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
# are added to the back. # are added to the back.
leftover_waiting_sequences = deque() leftover_waiting_sequences = deque()
while self.waiting: num_batched_tokens = 0
while self._passed_delay(now) and self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs( waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING) status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, ( assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt " "Waiting sequence group should have only one prompt "
"sequence.") "sequence.")
num_prompt_tokens = waiting_seqs[0].get_len() # get_len includes output tokens if the request has been
if num_prompt_tokens > self.prompt_limit: # preempted.
num_prefill_tokens = waiting_seqs[0].get_len()
if num_prefill_tokens > self.prompt_limit:
logger.warning( logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long" f"Input prompt ({num_prefill_tokens} tokens) is too "
f" and exceeds limit of {self.prompt_limit}") f"long and exceeds limit of {self.prompt_limit}")
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -207,8 +255,8 @@ class Scheduler: ...@@ -207,8 +255,8 @@ class Scheduler:
break break
elif can_allocate == AllocStatus.NEVER: elif can_allocate == AllocStatus.NEVER:
logger.warning( logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long" f"Input prompt ({num_prefill_tokens} tokens) is too "
f" and exceeds the capacity of block_manager") f"long and exceeds the capacity of block_manager")
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -218,8 +266,8 @@ class Scheduler: ...@@ -218,8 +266,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len( if (lora_int_id > 0 and lora_int_id not in curr_loras
curr_loras) >= self.lora_config.max_loras: and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group) leftover_waiting_sequences.appendleft(seq_group)
...@@ -227,8 +275,7 @@ class Scheduler: ...@@ -227,8 +275,7 @@ class Scheduler:
continue continue
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens += num_prefill_tokens
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
if (num_batched_tokens > if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens): self.scheduler_config.max_num_batched_tokens):
break break
...@@ -240,27 +287,24 @@ class Scheduler: ...@@ -240,27 +287,24 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens
if lora_int_id > 0: if lora_int_id > 0:
curr_loras.add(lora_int_id) curr_loras.add(lora_int_id)
self.waiting.popleft() self.waiting.popleft()
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(
ScheduledSequenceGroup(
seq_group=seq_group,
token_chunk_size=num_prefill_tokens))
self.waiting.extendleft(leftover_waiting_sequences) self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups: if scheduled or ignored_seq_groups:
self.prev_prompt = True
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
prompt_run=True, prompt_run=True,
num_batched_tokens=len(seq_lens) * num_batched_tokens=num_batched_tokens,
max(seq_lens) if seq_lens else 0,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
...@@ -313,8 +357,8 @@ class Scheduler: ...@@ -313,8 +357,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len( if (lora_int_id > 0 and lora_int_id not in curr_loras
curr_loras) >= self.lora_config.max_loras: and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
leftover_swapped.appendleft(seq_group) leftover_swapped.appendleft(seq_group)
...@@ -350,7 +394,11 @@ class Scheduler: ...@@ -350,7 +394,11 @@ class Scheduler:
for seq_group in self.running) for seq_group in self.running)
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running, scheduled_seq_groups=[
ScheduledSequenceGroup(seq_group=running_group,
token_chunk_size=1)
for running_group in self.running
],
prompt_run=False, prompt_run=False,
num_batched_tokens=num_batched_tokens, num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
...@@ -369,15 +417,25 @@ class Scheduler: ...@@ -369,15 +417,25 @@ class Scheduler:
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in scheduler_outputs.scheduled_seq_groups: for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now) seq_group.maybe_set_first_scheduled_time(now)
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
...@@ -385,17 +443,34 @@ class Scheduler: ...@@ -385,17 +443,34 @@ class Scheduler:
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
prefix=seq_group.prefix, computed_block_nums=common_computed_block_nums,
state=seq_group.state, state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.prompt_run else None,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq) self.block_manager.fork(parent_seq, child_seq)
def free_seq(self, seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None:
"""Free a sequence from a block table."""
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
...@@ -458,7 +533,8 @@ class Scheduler: ...@@ -458,7 +533,8 @@ class Scheduler:
assert len(seqs) == 1 assert len(seqs) == 1
for seq in seqs: for seq in seqs:
seq.status = SequenceStatus.WAITING seq.status = SequenceStatus.WAITING
self.block_manager.free(seq) self.free_seq(seq)
seq.reset_state_for_recompute()
# NOTE: For FCFS, we insert the preempted sequence group to the front # NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue. # of the waiting queue.
self.waiting.appendleft(seq_group) self.waiting.appendleft(seq_group)
...@@ -496,3 +572,19 @@ class Scheduler: ...@@ -496,3 +572,19 @@ class Scheduler:
blocks_to_swap_out.update(mapping) blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED seq.status = SequenceStatus.SWAPPED
def _passed_delay(self, now: float) -> bool:
    """Tell whether prompt scheduling may proceed at time ``now``.

    Also updates the bookkeeping used by the delay heuristic: when the
    previous step was flagged as a prompt step (``self.prev_prompt``), the
    time elapsed since ``self.prev_time`` is stored as
    ``self.last_prompt_latency``; ``self.prev_time`` is then set to ``now``
    and ``self.prev_prompt`` is cleared.

    Args:
        now: Current timestamp, on the same clock as the queued requests'
            ``metrics.arrival_time``.

    Returns:
        True when the delay is disabled (``delay_factor`` not positive),
        when nothing is waiting, when the oldest waiting request has waited
        longer than ``delay_factor * last_prompt_latency``, or when nothing
        is currently running. False otherwise.
    """
    if self.prev_prompt:
        self.last_prompt_latency = now - self.prev_time
    self.prev_time, self.prev_prompt = now, False

    # Delay scheduling prompts to let waiting queue fill up
    delay_factor = self.scheduler_config.delay_factor
    if not (delay_factor > 0 and self.waiting):
        return True

    earliest_arrival_time = min(e.metrics.arrival_time
                                for e in self.waiting)
    waited_long_enough = ((now - earliest_arrival_time) >
                          (delay_factor * self.last_prompt_latency))
    # With nothing running there is no reason to hold prompts back.
    return waited_long_enough or not self.running
...@@ -3,8 +3,10 @@ import dataclasses ...@@ -3,8 +3,10 @@ import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig) ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.utils import str_to_int_tuple
@dataclass @dataclass
...@@ -25,11 +27,13 @@ class EngineArgs: ...@@ -25,11 +27,13 @@ class EngineArgs:
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
block_size: int = 16 block_size: int = 16
enable_prefix_caching: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
code_revision: Optional[str] = None code_revision: Optional[str] = None
...@@ -38,6 +42,9 @@ class EngineArgs: ...@@ -38,6 +42,9 @@ class EngineArgs:
enforce_eager: bool = False enforce_eager: bool = False
max_context_len_to_capture: int = 8192 max_context_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False enable_lora: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
...@@ -45,6 +52,17 @@ class EngineArgs: ...@@ -45,6 +52,17 @@ class EngineArgs:
lora_dtype = 'auto' lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -167,12 +185,24 @@ class EngineArgs: ...@@ -167,12 +185,24 @@ class EngineArgs:
help='load model sequentially in multiple batches, ' help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor ' 'to avoid RAM OOM when using tensor '
'parallel and large models') 'parallel and large models')
parser.add_argument(
'--ray-workers-use-nsight',
action='store_true',
help='If specified, use nsight to profile ray workers')
# KV cache arguments # KV cache arguments
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32, 128], choices=[8, 16, 32, 128],
help='token block size') help='token block size')
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
default=EngineArgs.seed, default=EngineArgs.seed,
...@@ -188,6 +218,12 @@ class EngineArgs: ...@@ -188,6 +218,12 @@ class EngineArgs:
help='the fraction of GPU memory to be used for ' help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.' 'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.') 'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--forced-num-gpu-blocks',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
'of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens', parser.add_argument('--max-num-batched-tokens',
type=int, type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
...@@ -197,10 +233,12 @@ class EngineArgs: ...@@ -197,10 +233,12 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings', parser.add_argument(
type=int, '--max-logprobs',
default=EngineArgs.max_paddings, type=int,
help='maximum number of paddings in a batch') default=EngineArgs.max_logprobs,
help=('max number of log probs to return logprobs is specified in'
' SamplingParams'))
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='disable logging statistics') help='disable logging statistics')
...@@ -231,6 +269,25 @@ class EngineArgs: ...@@ -231,6 +269,25 @@ class EngineArgs:
action='store_true', action='store_true',
default=EngineArgs.disable_custom_all_reduce, default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig') help='See ParallelConfig')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# LoRA related configs # LoRA related configs
parser.add_argument('--enable-lora', parser.add_argument('--enable-lora',
action='store_true', action='store_true',
...@@ -269,6 +326,43 @@ class EngineArgs: ...@@ -269,6 +326,43 @@ class EngineArgs:
default=EngineArgs.device, default=EngineArgs.device,
choices=["auto", "cuda", "neuron"], choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser.add_argument(
'--image-input-type',
type=str,
default=None,
choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM. '
'Should be one of "pixel_values" or "image_features".'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous'
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
return parser return parser
@classmethod @classmethod
...@@ -282,27 +376,39 @@ class EngineArgs: ...@@ -282,27 +376,39 @@ class EngineArgs:
def create_engine_configs( def create_engine_configs(
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig]]: DeviceConfig, Optional[LoRAConfig],
Optional[VisionLanguageConfig]]:
device_config = DeviceConfig(self.device) device_config = DeviceConfig(self.device)
model_config = ModelConfig( model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.download_dir, self.load_format, self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision, self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization, self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture) self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window()) self.forced_num_gpu_blocks,
parallel_config = ParallelConfig(self.pipeline_parallel_size, model_config.get_sliding_window(),
self.tensor_parallel_size, self.enable_prefix_caching)
self.worker_use_ray, parallel_config = ParallelConfig(
self.max_parallel_loading_workers, self.pipeline_parallel_size, self.tensor_parallel_size,
self.disable_custom_all_reduce) self.worker_use_ray, self.max_parallel_loading_workers,
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.disable_custom_all_reduce,
self.max_num_seqs, TokenizerPoolConfig.create_config(
model_config.max_model_len, self.tokenizer_pool_size,
self.max_paddings) self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
...@@ -310,8 +416,25 @@ class EngineArgs: ...@@ -310,8 +416,25 @@ class EngineArgs:
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
)
else:
vision_language_config = None
return (model_config, cache_config, parallel_config, scheduler_config, return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config) device_config, lora_config, vision_language_config)
@dataclass @dataclass
......
import asyncio import asyncio
import os
import time import time
from functools import partial from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Union, AsyncIterator) Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray from vllm.engine.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__) logger = init_logger(__name__)
# Timeout, in seconds, for a single engine iteration. Read from the
# VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable (default "60") and
# parsed with int(), so a non-integer value raises ValueError at import time.
# It is the timeout passed to asyncio.wait_for() around engine_step() in the
# engine loop.
ENGINE_ITERATION_TIMEOUT_S = int(
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
class AsyncEngineDeadError(RuntimeError): class AsyncEngineDeadError(RuntimeError):
pass pass
def _raise_exception_on_finish(task: asyncio.Task, def _raise_exception_on_finish(
request_tracker: "RequestTracker") -> None: task: asyncio.Task, error_callback: Callable[[Exception],
None]) -> None:
msg = ("Task finished unexpectedly. This should never happen! " msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.") "Please open an issue on Github.")
exception = None
try: try:
try: task.result()
task.result() # NOTE: This will be thrown if task exits normally (which it should not)
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from exc
raise AsyncEngineDeadError(msg) raise AsyncEngineDeadError(msg)
except Exception as exc: except Exception as e:
request_tracker.propagate_exception(exc) exception = e
raise exc logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e
class AsyncStream: class AsyncStream:
...@@ -47,7 +55,7 @@ class AsyncStream: ...@@ -47,7 +55,7 @@ class AsyncStream:
self._queue = asyncio.Queue() self._queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: RequestOutput) -> None: def put(self, item: Union[RequestOutput, Exception]) -> None:
if self._finished: if self._finished:
return return
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -78,13 +86,13 @@ class RequestTracker: ...@@ -78,13 +86,13 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue() self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream, self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue() dict]] = asyncio.Queue()
self.new_requests_event = None self.new_requests_event = asyncio.Event()
def __contains__(self, item): def __contains__(self, item):
return item in self._request_streams return item in self._request_streams
def init_event(self): def __len__(self) -> int:
self.new_requests_event = asyncio.Event() return len(self._request_streams)
def propagate_exception(self, def propagate_exception(self,
exc: Exception, exc: Exception,
...@@ -93,9 +101,11 @@ class RequestTracker: ...@@ -93,9 +101,11 @@ class RequestTracker:
(all if request_id is None).""" (all if request_id is None)."""
if request_id is not None: if request_id is not None:
self._request_streams[request_id].put(exc) self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else: else:
for stream in self._request_streams.values(): for rid, stream in self._request_streams.items():
stream.put(exc) stream.put(exc)
self.abort_request(rid)
def process_request_output(self, def process_request_output(self,
request_output: RequestOutput, request_output: RequestOutput,
...@@ -110,6 +120,17 @@ class RequestTracker: ...@@ -110,6 +120,17 @@ class RequestTracker:
logger.info(f"Finished request {request_id}.") logger.info(f"Finished request {request_id}.")
self.abort_request(request_id) self.abort_request(request_id)
def process_exception(self,
                      request_id: str,
                      exception: Exception,
                      *,
                      verbose: bool = False) -> None:
    """Propagate an exception from the engine to one request's stream.

    The exception is put on the stream registered under ``request_id`` and
    the request is then aborted via ``self.abort_request``.

    Args:
        request_id: Key of the target stream in ``self._request_streams``.
        exception: The exception object to deliver to the stream.
        verbose: When True, log an info line for the request.

    Raises:
        KeyError: If no stream is registered under ``request_id``.
    """
    self._request_streams[request_id].put(exception)
    if verbose:
        logger.info(f"Finished request {request_id}.")
    self.abort_request(request_id)
def add_request(self, request_id: str, def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream: **engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background """Add a request to be sent to the engine on the next background
...@@ -161,12 +182,15 @@ class RequestTracker: ...@@ -161,12 +182,15 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream self._request_streams[stream.request_id] = stream
new_requests.append(new_request) new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests return new_requests, finished_requests
async def wait_for_new_requests(self): async def wait_for_new_requests(self):
await self.new_requests_event.wait() if not self.has_new_requests():
await self.new_requests_event.wait()
self.new_requests_event.clear()
def has_new_requests(self) -> bool:
    """Return True when the pending new-request queue is not empty."""
    return not self._new_requests.empty()
class _AsyncLLMEngine(LLMEngine): class _AsyncLLMEngine(LLMEngine):
...@@ -186,17 +210,10 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -186,17 +210,10 @@ class _AsyncLLMEngine(LLMEngine):
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
all_outputs = await self._run_workers_async( output = await self.model_executor.execute_model_async(
"execute_model", seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
driver_kwargs={ scheduler_outputs.blocks_to_swap_out,
"seq_group_metadata_list": seq_group_metadata_list, scheduler_outputs.blocks_to_copy)
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
else: else:
output = [] output = []
...@@ -225,7 +242,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -225,7 +242,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
...@@ -238,43 +255,16 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -238,43 +255,16 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
lora_request=lora_request) lora_request=lora_request)
return self.add_request( return self.add_request(request_id,
request_id, prompt=prompt,
prompt=prompt, prompt_token_ids=prompt_token_ids,
prompt_token_ids=prompt_token_ids, sampling_params=sampling_params,
sampling_params=sampling_params, arrival_time=arrival_time,
arrival_time=arrival_time, lora_request=lora_request,
lora_request=lora_request, multi_modal_data=multi_modal_data)
prefix_pos=prefix_pos,
)
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
if driver_args is None: async def check_health_async(self) -> None:
driver_args = args self.model_executor.check_health()
if driver_kwargs is None:
driver_kwargs = kwargs
# Run the driver worker asynchronously.
driver_executor = getattr(self.driver_worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(driver_executor, *driver_args, **driver_kwargs)))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
class AsyncLLMEngine: class AsyncLLMEngine:
...@@ -326,27 +316,90 @@ class AsyncLLMEngine: ...@@ -326,27 +316,90 @@ class AsyncLLMEngine:
# collected # collected
self._background_loop_unshielded = None self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker() self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
@classmethod
def from_engine_args(
    cls,
    engine_args: AsyncEngineArgs,
    start_engine_loop: bool = True,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLMEngine":
    """Creates an async LLM engine from the engine arguments.

    Args:
        engine_args: Source of the engine configs and of the logging and
            Ray-related flags read below.
        start_engine_loop: Forwarded unchanged to the constructor.
        usage_context: Forwarded unchanged to the constructor.

    Returns:
        The newly constructed engine.

    Raises:
        NotImplementedError: If the configured device type is "neuron".
        AssertionError: If Ray is not requested but the parallel config's
            world size is not 1.
    """
    # Create the engine configs. The tuple is ordered (model, cache,
    # parallel, scheduler, device, lora, vision-language).
    engine_configs = engine_args.create_engine_configs()
    parallel_config = engine_configs[2]
    device_config = engine_configs[4]

    # Pick the executor class. The executor modules are imported lazily so
    # only the selected backend is loaded.
    if device_config.device_type == "neuron":
        raise NotImplementedError("Neuron is not supported for "
                                  "async engine yet.")
    elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
        initialize_ray_cluster(parallel_config)
        from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
        executor_class = RayGPUExecutorAsync
    else:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        from vllm.executor.gpu_executor import GPUExecutorAsync
        executor_class = GPUExecutorAsync

    # Create the async LLM engine.
    engine = cls(
        parallel_config.worker_use_ray,
        engine_args.engine_use_ray,
        *engine_configs,
        executor_class,
        log_requests=not engine_args.disable_log_requests,
        log_stats=not engine_args.disable_log_stats,
        max_log_len=engine_args.max_log_len,
        start_engine_loop=start_engine_loop,
        usage_context=usage_context,
    )
    return engine
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
return (self.background_loop is not None return (self.background_loop is not None
and not self.background_loop.done()) and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
    """Whether the engine has errored or its background task is done.

    True when ``self.errored`` is set, or when a background loop has been
    created and the underlying unshielded task reports ``done()``.
    """
    return self.errored or (self.background_loop is not None
                            and self._background_loop_unshielded.done())
def get_tokenizer(self): @property
return self.engine.tokenizer.tokenizer def errored(self) -> bool:
return self._errored_with is not None
def set_errored(self, exc: Exception) -> None:
    """Record ``exc`` as the error the engine failed with.

    The value is stored in ``self._errored_with``.
    """
    self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
    """Mark the engine as errored and forward ``exc`` to the tracker.

    Calls ``self.set_errored(exc)`` first, then
    ``self._request_tracker.propagate_exception(exc)`` with no request id.
    """
    self.set_errored(exc)
    self._request_tracker.propagate_exception(exc)
async def get_tokenizer(self) -> "PreTrainedTokenizer":
    """Return the tokenizer held by the wrapped engine.

    When ``self.engine_use_ray`` is set, the result of
    ``engine.get_tokenizer.remote()`` is awaited; otherwise
    ``engine.get_tokenizer()`` is called directly.
    """
    if self.engine_use_ray:
        return await self.engine.get_tokenizer.remote()
    return self.engine.get_tokenizer()
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running: if self.is_running:
raise RuntimeError("Background loop is already running.") raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event() # Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop()) ).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback( self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, partial(_raise_exception_on_finish,
request_tracker=self._request_tracker)) error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded) self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, def _init_engine(self, *args,
...@@ -379,10 +432,18 @@ class AsyncLLMEngine: ...@@ -379,10 +432,18 @@ class AsyncLLMEngine:
for new_request in new_requests: for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue. # Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead # TODO: Maybe add add_request_batch to reduce Ray overhead
if self.engine_use_ray: try:
await self.engine.add_request.remote(**new_request) if self.engine_use_ray:
else: await self.engine.add_request.remote(**new_request)
await self.engine.add_request_async(**new_request) else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
# TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception(
new_request["request_id"],
e,
verbose=self.log_requests,
)
if finished_requests: if finished_requests:
await self._engine_abort(finished_requests) await self._engine_abort(finished_requests)
...@@ -406,12 +467,23 @@ class AsyncLLMEngine: ...@@ -406,12 +467,23 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False has_requests_in_progress = False
while True: while True:
if not has_requests_in_progress: if not has_requests_in_progress:
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step() logger.debug("Got new requests!")
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
raise
await asyncio.sleep(0) await asyncio.sleep(0)
async def add_request( async def add_request(
...@@ -422,7 +494,7 @@ class AsyncLLMEngine: ...@@ -422,7 +494,7 @@ class AsyncLLMEngine:
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
shortened_prompt = prompt shortened_prompt = prompt
...@@ -435,7 +507,6 @@ class AsyncLLMEngine: ...@@ -435,7 +507,6 @@ class AsyncLLMEngine:
max_log_len] max_log_len]
logger.info(f"Received request {request_id}: " logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, " f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling_params: {sampling_params}, " f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, " f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.") f"lora_request: {lora_request}.")
...@@ -473,7 +544,8 @@ class AsyncLLMEngine: ...@@ -473,7 +544,8 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prefix_pos=prefix_pos) multi_modal_data=multi_modal_data,
)
return stream return stream
...@@ -484,7 +556,7 @@ class AsyncLLMEngine: ...@@ -484,7 +556,7 @@ class AsyncLLMEngine:
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -500,11 +572,7 @@ class AsyncLLMEngine: ...@@ -500,11 +572,7 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix multi_modal_data: Multi modal data per request.
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine for the The output `RequestOutput` objects from the LLMEngine for the
...@@ -554,8 +622,7 @@ class AsyncLLMEngine: ...@@ -554,8 +622,7 @@ class AsyncLLMEngine:
>>> ... >>> ...
""" """
# Preprocess the request. # Preprocess the request.
# This should not be used for logging, as it is monotonic time. arrival_time = time.time()
arrival_time = time.monotonic()
try: try:
stream = await self.add_request( stream = await self.add_request(
...@@ -565,7 +632,7 @@ class AsyncLLMEngine: ...@@ -565,7 +632,7 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prefix_pos=prefix_pos, multi_modal_data=multi_modal_data,
) )
async for request_output in stream: async for request_output in stream:
...@@ -613,30 +680,24 @@ class AsyncLLMEngine: ...@@ -613,30 +680,24 @@ class AsyncLLMEngine:
else: else:
return self.engine.get_model_config() return self.engine.get_model_config()
# NOTE(review): a second `from_engine_args` appears earlier in this text, taking
# `usage_context` and passing an executor class instead of a placement group.
# Presumably this copy is the pre-change side of the diff -- confirm which
# definition is live before relying on it.
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
# The third entry of the configs tuple is used as the parallel config.
parallel_config = engine_configs[2]
# Initialize the cluster.
placement_group = initialize_cluster(parallel_config,
engine_args.engine_use_ray)
# Create the async LLM engine.
# All engine configs are unpacked positionally, followed by the placement
# group returned above; the logging flags are the negation of the
# corresponding `disable_*` engine arguments.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine
async def do_log_stats(self) -> None: async def do_log_stats(self) -> None:
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.do_log_stats.remote() await self.engine.do_log_stats.remote()
else: else:
self.engine.do_log_stats() self.engine.do_log_stats()
async def check_health(self) -> None:
    """Raises an error if engine is unhealthy.

    Raises:
        AsyncEngineDeadError: If ``self.is_stopped`` is true.
        RuntimeError: If the Ray-hosted engine's health check raises
            ``ray.exceptions.RayActorError`` (chained as the cause).
    """
    started = time.perf_counter()
    logger.debug("Starting health check...")
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    if self.engine_use_ray:
        try:
            await self.engine.check_health.remote()
        except ray.exceptions.RayActorError as e:
            raise RuntimeError("Engine is dead.") from e
    else:
        await self.engine.check_health_async()
    logger.debug(f"Health check took {time.perf_counter()-started}s")
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment