Commit 705f6a35 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
......@@ -7,9 +7,17 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
if is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
class TorchSDPABackend(AttentionBackend):
......@@ -23,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -137,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
......@@ -150,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
"""
assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -197,13 +211,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
value[:, start:end, :],
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else:
......@@ -236,7 +251,7 @@ def _make_alibi_bias(
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
......@@ -248,7 +263,7 @@ def _make_alibi_bias(
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
......@@ -262,7 +277,7 @@ def _make_sliding_window_bias(
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
......
"""Attention backend utils"""
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
......@@ -6,10 +6,11 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -28,8 +29,8 @@ class XFormersBackend(AttentionBackend):
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
......@@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
......@@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor] = None
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
......@@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
......@@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
query_start_loc=query_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
......@@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
def _get_attn_bias(
attn_metadata: XFormersMetadata,
attn_type: AttentionType,
) -> Optional[AttentionBias]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if attn_type == AttentionType.DECODER:
return attn_metadata.attn_bias
elif attn_type == AttentionType.ENCODER:
return attn_metadata.encoder_attn_bias
else:
# attn_type == AttentionType.ENCODER_DECODER
return attn_metadata.cross_attn_bias
def _set_attn_bias(
attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if attn_type == AttentionType.DECODER:
attn_metadata.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
attn_metadata.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
attn_metadata.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_seq_len_block_table_args(
attn_metadata: XFormersMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor],
attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns:
shape = [num_tokens, num_heads * head_size]
"""
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype, kv_scale)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
kv_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
......@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# block tables are empty if the prompt does not have a cached
# prefix.
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta)
query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
......@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_decode_seq_len,
block_tables_arg,
seq_lens_arg,
max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
......@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
"""
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
......@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if attn_metadata.attn_bias is None:
attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if (attn_type == AttentionType.ENCODER_DECODER):
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None
# Default enc/dec cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
elif attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
# Default encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
else:
assert attn_metadata.seq_lens is not None
# Default decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_metadata.attn_bias = [attn_bias]
attn_bias = [attn_bias]
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
assert attn_metadata.seq_lens is not None
attn_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
_set_attn_bias(attn_metadata, attn_bias, attn_type)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
......@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query,
key,
value,
attn_bias=attn_metadata.attn_bias[0],
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale)
return out.view_as(original_query)
......@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
assert attn_metadata.seq_lens is not None
output = torch.empty_like(original_query)
start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens):
......@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_metadata.attn_bias[i],
attn_bias=attn_bias[i],
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
......@@ -431,8 +747,8 @@ def _make_alibi_bias(
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
) -> List[AttentionBias]:
attn_biases: List[AttentionBias] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
......
......@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
class Attention(nn.Module):
......@@ -56,15 +57,19 @@ class Attention(nn.Module):
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
assert isinstance(quant_method, Fp8KVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if "fp8" in self.kv_cache_dtype:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back to self._kv_scale
# in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
......@@ -85,9 +90,16 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._kv_scale)
return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._kv_scale,
attn_type=attn_type)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
......
......@@ -2,13 +2,14 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8)
and current_platform.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
......@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
v,
cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q,
sm_scale=sm_scale)
\ No newline at end of file
sm_scale=sm_scale)
......@@ -4,9 +4,35 @@
from functools import lru_cache
import numpy as np
import torch
import triton
from scipy import sparse
class csr_matrix:
"""Simple implementation of CSR matrix conversion without scipy.
This replaced scipy.sparse.csr_matrix() previously used."""
def __init__(self, input_array):
if not isinstance(input_array, np.ndarray):
raise ValueError("Input must be a NumPy array")
self.shape = input_array.shape
rows, cols = self.shape
data = []
indices = []
indptr = [0]
for i in range(rows):
for j in range(cols):
if input_array[i, j]:
data.append(input_array[i, j])
indices.append(j)
indptr.append(len(indices))
self.data = np.array(data)
self.indices = np.array(indices)
self.indptr = np.array(indptr)
def dense_to_crow_col(x: torch.Tensor):
......@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert x.dim() in (2, 3)
if x.dim() == 2:
x = x[None]
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols)
......@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head(
):
"""
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
otherwise, None
"""
with torch.no_grad():
......@@ -148,10 +174,10 @@ def get_sparse_attn_mask(
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be OOM if it
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
"""
assert dense_mask_type in ("binary", "bias")
......
from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules
import torch
from vllm import _custom_ops as ops
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
*args,
) -> None:
ipex_modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache,
slot_mapping.flatten().int())
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
*args,
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[2]
head_mapping = torch.arange(
0,
num_kv_heads,
device="cpu",
dtype=torch.int32,
).view(num_kv_heads,
1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
ipex_modules.PagedAttention.single_query_cached_kv_attention(
output, query.contiguous(), key_cache, value_cache, head_mapping,
scale, block_tables, context_lens, block_size, max_context_len,
alibi_slopes)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
*args,
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
*args,
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
......@@ -5,6 +5,8 @@ import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
if triton.__version__ >= "2.1.0":
@triton.jit
......@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None,
sliding_window=None):
cap = torch.cuda.get_device_capability()
cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if q.dtype is torch.float32:
BLOCK = BLOCK // 2
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
......@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0":
b_ctx_len,
alibi_slopes,
v_cache.shape[3],
8,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
......@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0":
b_seq_len,
b_ctx_len,
v_cache.shape[3],
8,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
......
......@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__)
......@@ -17,7 +17,10 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
@lru_cache(maxsize=None)
......@@ -57,15 +60,29 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.IPEX:
assert is_xpu(), RuntimeError(
"IPEX attention backend is only used for the XPU device.")
logger.info("Using IPEX attention backend.")
from vllm.attention.backends.ipex_attn import IpexAttnBackend
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else:
raise ValueError("Invalid attention backend.")
......@@ -80,7 +97,6 @@ def which_attn_to_use(
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
......@@ -100,6 +116,21 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
if is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
......
......@@ -3,52 +3,9 @@ from typing import List
from vllm.utils import Device
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Logical blocks are used to represent the states of the corresponding
physical blocks in the KV cache.
"""
def __init__(
self,
block_number: int,
block_size: int,
) -> None:
self.block_number = block_number
self.block_size = block_size
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
def is_empty(self) -> bool:
return self.num_tokens == 0
def get_num_empty_slots(self) -> int:
return self.block_size - self.num_tokens
def is_full(self) -> bool:
return self.num_tokens == self.block_size
def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots()
curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids)
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]
def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]
class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
......
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
......@@ -23,6 +25,17 @@ logger = init_logger(__name__)
_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
]
class ModelConfig:
"""Configuration for the model.
......@@ -105,6 +118,7 @@ class ModelConfig:
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["MultiModalConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
......@@ -123,12 +137,10 @@ class ModelConfig:
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
if max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
......@@ -137,6 +149,17 @@ class ModelConfig:
code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
......@@ -144,6 +167,8 @@ class ModelConfig:
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = multimodal_config
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
......@@ -212,7 +237,7 @@ class ModelConfig:
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if (self.quantization
not in ["marlin", "gptq_marlin_24", "gptq_marlin"]):
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
......@@ -228,7 +253,8 @@ class ModelConfig:
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_text_config.num_attention_heads
total_num_attention_heads = getattr(self.hf_text_config,
"num_attention_heads", 0)
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
......@@ -236,13 +262,13 @@ class ModelConfig:
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
......@@ -251,8 +277,7 @@ class ModelConfig:
"BitAndBytes quantization with TP or PP is not supported yet.")
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
......@@ -307,7 +332,11 @@ class ModelConfig:
return 1
# For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]:
if self.hf_config.model_type == "mpt":
if "kv_n_heads" in self.hf_config.attn_config:
return self.hf_config.attn_config["kv_n_heads"]
return self.hf_config.num_attention_heads
if self.hf_config.model_type == "dbrx":
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
......@@ -341,12 +370,43 @@ class ModelConfig:
def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \
parallel_config.tensor_parallel_size
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])
def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])
class CacheConfig:
......@@ -611,45 +671,50 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from torch.cuda import device_count
from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray is not None
if device_count() < self.world_size:
ray_found = ray_utils.ray_is_available()
if cuda_device_count_stateless() < self.world_size:
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference")
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
backend = "ray"
elif ray_found:
from ray.util import get_current_placement_group
if self.placement_group or get_current_placement_group():
if self.placement_group:
backend = "ray"
else:
from ray import is_initialized as ray_is_initialized
if ray_is_initialized():
from ray.util import get_current_placement_group
if get_current_placement_group():
backend = "ray"
self.distributed_executor_backend = backend
logger.info("Defaulting to use %s for distributed inference",
backend)
self._verify_args()
self.rank = 0
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend == "mp"):
raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
if self.distributed_executor_backend == "ray":
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"):
raise ValueError("Unable to use nsight profiling unless workers "
......@@ -720,7 +785,6 @@ class SchedulerConfig:
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self._verify_args()
def _verify_args(self) -> None:
......@@ -754,8 +818,14 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu():
self.device_type = "cpu"
elif is_xpu():
self.device_type = "xpu"
else:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
......@@ -765,8 +835,10 @@ class DeviceConfig:
self.device_type = device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
if self.device_type in ["neuron", "openvino"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else:
# Set device with device type
self.device = torch.device(self.device_type)
......@@ -785,6 +857,7 @@ class SpeculativeConfig:
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
......@@ -792,6 +865,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
......@@ -807,8 +883,11 @@ class SpeculativeConfig:
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
......@@ -825,30 +904,37 @@ class SpeculativeConfig:
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""
if speculative_model is None and num_speculative_tokens is None:
if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None
if speculative_model is not None and num_speculative_tokens is None:
raise ValueError(
"Expected both speculative_model and "
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
assert (speculative_model is not None
and num_speculative_tokens is not None)
if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
......@@ -902,6 +988,25 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs,
)
draft_hf_config = draft_model_config.hf_config
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict:
# Verify provided value doesn't exceed the maximum
# supported by the draft model.
raise ValueError(
"This speculative model supports a maximum of "
f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.")
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
......@@ -911,7 +1016,19 @@ class SpeculativeConfig:
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
target_parallel_config,
speculative_draft_tensor_parallel_size))
if num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter.")
if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
return SpeculativeConfig(
draft_model_config,
......@@ -920,6 +1037,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
)
@staticmethod
......@@ -959,16 +1081,26 @@ class SpeculativeConfig:
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config. In the future the
draft worker can have a different parallel strategy, e.g. TP=1.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be"
f"other value than 1")
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
......@@ -991,6 +1123,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
):
"""Create a SpeculativeConfig object.
......@@ -1004,6 +1139,19 @@ class SpeculativeConfig:
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
......@@ -1012,6 +1160,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self._verify_args()
......@@ -1023,6 +1176,31 @@ class SpeculativeConfig:
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)
# Validate and set draft token acceptance related settings.
if (self.draft_token_acceptance_method is None):
raise ValueError("draft_token_acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method !=
'typical_acceptance_sampler'):
raise ValueError(
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.draft_token_acceptance_method}")
if (self.typical_acceptance_sampler_posterior_threshold < 0
or self.typical_acceptance_sampler_posterior_alpha < 0):
raise ValueError(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f"typical_acceptance_sampler_posterior_threshold = "
f"{self.typical_acceptance_sampler_posterior_threshold} and "
f"typical_acceptance_sampler_posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_alpha}")
@property
def num_lookahead_slots(self) -> int:
......@@ -1094,79 +1272,49 @@ class LoRAConfig:
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.")
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
@dataclass
class VisionLanguageConfig:
"""Configs the input data format and how models should run for
vision language models."""
class ImageInputType(enum.Enum):
"""Image input type into the vision language model.
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
An image roughly goes through the following transformation:
Raw image --> pixel values --> image features --> image embeddings.
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES = enum.auto()
IMAGE_FEATURES = enum.auto()
image_input_type: ImageInputType
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
# The image processor to load from HuggingFace
image_processor: Optional[str]
image_processor_revision: Optional[str]
@classmethod
def get_image_input_enum_type(cls, value: str) -> ImageInputType:
"""Get the image input type from a string."""
def __post_init__(self):
library_name = 'peft'
try:
return cls.ImageInputType[value.upper()]
except KeyError as e:
raise ValueError(f"{value} is not a valid choice. "
f"Expecting to choose from "
f"{[x.name for x in cls.ImageInputType]}.") from e
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
def get_image_token_text(
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""Get the image token placeholder text to be inserted into the
text prompt and the string representation of the image token id.
"""
image_token_str = tokenizer.decode(self.image_token_id)
return image_token_str * self.image_feature_size, image_token_str
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
__import__(library_name)
except ImportError as e:
raise ImportError(
f"'{library_name}' is not installed for prompt adapter support."
f"Please install it using 'pip install {library_name}'."
) from e
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
result["disable_image_processor"] = self.image_processor is None
return result
@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
multimodal models."""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = {
......@@ -1194,10 +1342,16 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32
# models.
logger.info("Casting torch.float32 to torch.float16.")
torch_dtype = torch.float16
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
......@@ -1282,7 +1436,10 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] != "su" \
and rope_scaling["type"] != "longrope":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
......@@ -1357,6 +1514,17 @@ class DecodingConfig:
f"must be one of {valid_guided_backends}")
@dataclass
class ObservabilityConfig:
"""Configuration for observability."""
otlp_traces_endpoint: Optional[str] = None
def __post_init__(self):
if not is_otel_installed() and self.otlp_traces_endpoint is not None:
raise ValueError("OpenTelemetry packages must be installed before "
"configuring 'otlp_traces_endpoint'")
@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
......@@ -1370,9 +1538,11 @@ class EngineConfig:
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
multimodal_config: Optional[MultiModalConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
......@@ -1384,6 +1554,9 @@ class EngineConfig:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
......
from typing import List, Optional
from vllm.core.block.common import BlockList
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
......@@ -47,12 +48,10 @@ class BlockTable:
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
self._blocks: BlockList = BlockList(_blocks)
self._max_block_sliding_window = max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
self._num_full_slots = self._get_num_token_ids()
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
......@@ -88,11 +87,18 @@ class BlockTable:
"""
assert not self._is_allocated
assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self.update(blocks)
self._num_full_slots = len(token_ids)
def update(self, blocks: List[Block]) -> None:
"""Resets the table to the newly provided blocks
(with their corresponding block ids)
"""
self._blocks.update(blocks)
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0,
......@@ -140,11 +146,11 @@ class BlockTable:
num_lookahead_slots)
# Update the blocks with the new tokens
blocks = self._blocks[self._num_full_slots // self._block_size:]
first_block_idx = self._num_full_slots // self._block_size
token_blocks = self._chunk_token_blocks_for_append(token_ids)
for block, token_block in zip(blocks, token_blocks):
block.append_token_ids(token_block)
for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block)
self._num_full_slots += len(token_ids)
......@@ -174,8 +180,8 @@ class BlockTable:
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device))
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
......@@ -209,12 +215,12 @@ class BlockTable:
is set to `None`.
"""
assert self._is_allocated
for block in self._blocks:
for block in self.blocks:
self._allocator.free(block)
self._blocks = []
self._blocks.reset()
@property
def physical_block_ids(self) -> List[Optional[int]]:
def physical_block_ids(self) -> List[int]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
......@@ -228,7 +234,7 @@ class BlockTable:
BlockTable.
"""
assert self._is_allocated
return [block.block_id for block in self._blocks]
return self._blocks.ids()
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence.
......@@ -252,18 +258,32 @@ class BlockTable:
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
blocks = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
prev_block = self._allocator.allocate_immutable(
prev_block, token_ids=block_token_ids, device=device)
blocks: List[Block] = []
block_token_ids = []
tail_token_ids = []
for cur_token_ids in chunk_list(token_ids, self._block_size):
if len(cur_token_ids) == self._block_size:
block_token_ids.append(cur_token_ids)
else:
# Else, partially fill a mutable block with token ids.
prev_block = self._allocator.allocate_mutable(
prev_block=prev_block, device=device)
prev_block.append_token_ids(block_token_ids)
blocks.append(prev_block)
tail_token_ids.append(cur_token_ids)
if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids,
device=device))
prev_block = blocks[-1]
if tail_token_ids:
assert len(tail_token_ids) == 1
cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device)
block.append_token_ids(cur_token_ids)
blocks.append(block)
return blocks
......@@ -274,18 +294,25 @@ class BlockTable:
if not self._is_allocated:
return token_ids
for block in self._blocks:
for block in self.blocks:
token_ids.extend(block.token_ids)
return token_ids
def _get_num_token_ids(self) -> int:
res = 0
for block in self.blocks:
res += len(block.token_ids)
return res
@property
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
@property
def blocks(self) -> Optional[List[Block]]:
return self._blocks
def blocks(self) -> List[Block]:
return self._blocks.list()
@property
def _num_empty_slots(self) -> int:
......
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator
......@@ -95,64 +96,40 @@ class CopyOnWriteTracker:
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
counting and block allocation.
conjunction with a RefCounter.
Args:
refcounter (RefCounter): The reference counter used to track block
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
def __init__(
self,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
def __init__(self, refcounter: RefCounterProtocol):
self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
self._refcounter = refcounter
self._allocator = allocator
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
def is_appendable(self, block: Block) -> bool:
"""Checks if the block is shared or not. If shared, then it cannot
be appended and needs to be duplicated via copy-on-write
"""
block_id = block.block_id
if block_id is None:
return block_id
return True
refcount = self._refcounter.get(block_id)
assert refcount != 0
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)
# Allocate a fresh new block.
block_id = self._allocator.allocate_mutable(
prev_block=block.prev_block).block_id
return refcount <= 1
# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes.append((src_block_id, block_id))
return block_id
def record_cow(self, src_block_id: Optional[BlockId],
trg_block_id: Optional[BlockId]) -> None:
"""Records a copy-on-write operation from source to target block id
Args:
src_block_id (BlockId): The source block id from which to copy
the data
trg_block_id (BlockId): The target block id to which the data
is copied
"""
assert src_block_id is not None
assert trg_block_id is not None
self._copy_on_writes.append((src_block_id, trg_block_id))
def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
"""Clears the copy-on-write tracking information and returns the current
......@@ -172,6 +149,139 @@ class CopyOnWriteTracker:
return cows
class BlockPool:
"""Used to pre-allocate block objects, in order to avoid excessive python
object allocations/deallocations.
The pool starts from "pool_size" objects and will increase to more objects
if necessary
Note that multiple block objects may point to the same physical block id,
which is why this pool is needed, so that it will be easier to support
prefix caching and more complicated sharing of physical blocks.
"""
def __init__(self, block_size: int, create_block: Block.Factory,
allocator: BlockAllocator, pool_size: int):
self._block_size = block_size
self._create_block = create_block
self._allocator = allocator
self._pool_size = pool_size
assert self._pool_size >= 0
self._free_ids: Deque[int] = deque(range(self._pool_size))
self._pool = []
for i in range(self._pool_size):
self._pool.append(
self._create_block(prev_block=None,
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
def increase_pool(self):
"""Doubles the internal pool size
"""
cur_pool_size = self._pool_size
new_pool_size = cur_pool_size * 2
self._pool_size = new_pool_size
self._free_ids += deque(range(cur_pool_size, new_pool_size))
for i in range(cur_pool_size, new_pool_size):
self._pool.append(
self._create_block(prev_block=None,
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
def init_block(self, prev_block: Optional[Block], token_ids: List[int],
block_size: int, physical_block_id: Optional[int]) -> Block:
if len(self._free_ids) == 0:
self.increase_pool()
assert len(self._free_ids) > 0
pool_id = self._free_ids.popleft()
block = self._pool[pool_id]
block.__init__( # type: ignore[misc]
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
allocator=block._allocator, # type: ignore[attr-defined]
block_id=physical_block_id)
block.pool_id = pool_id # type: ignore[attr-defined]
return block
def free_block(self, block: Block) -> None:
self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined]
class BlockList:
"""This class is an optimization to allow fast-access to physical
block ids. It maintains a block id list that is updated with the
block list and this avoids the need to reconstruct the block id
list on every iteration of the block manager
"""
def __init__(self, blocks: List[Block]):
self._blocks: List[Block] = []
self._block_ids: List[int] = []
self.update(blocks)
def _add_block_id(self, block_id: Optional[BlockId]) -> None:
assert block_id is not None
self._block_ids.append(block_id)
def _update_block_id(self, block_index: int,
new_block_id: Optional[BlockId]) -> None:
assert new_block_id is not None
self._block_ids[block_index] = new_block_id
def update(self, blocks: List[Block]):
self._blocks = blocks
# Cache block ids for fast query
self._block_ids = []
for block in self._blocks:
self._add_block_id(block.block_id)
def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
block = self._blocks[block_index]
prev_block_id = block.block_id
block.append_token_ids(token_ids)
# CoW or promotion may update the internal block_id
if prev_block_id != block.block_id:
self._update_block_id(block_index, block.block_id)
def append(self, new_block: Block):
self._blocks.append(new_block)
self._add_block_id(new_block.block_id)
def __len__(self) -> int:
return len(self._blocks)
def __getitem__(self, block_index: int) -> Block:
return self._blocks[block_index]
def __setitem__(self, block_index: int, new_block: Block) -> None:
self._blocks[block_index] = new_block
self._update_block_id(block_index, new_block.block_id)
def reset(self):
self._blocks = []
self._block_ids = []
def list(self) -> List[Block]:
return self._blocks
def ids(self) -> List[int]:
return self._block_ids
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.
......
......@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def allocate_or_get_null_block(self) -> Block:
if self._null_block is None:
self._null_block = NullBlock(
self.allocate_mutable(None, Device.GPU))
self.allocate_mutable_block(None, Device.GPU))
return self._null_block
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
......@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable(prev_block)
return self._allocators[device].allocate_mutable_block(prev_block)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device]) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids)
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
......@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable(
return self._allocators[device].allocate_immutable_block(
prev_block, token_ids)
def free(self, block: Block) -> None:
......@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block)
allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
......@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""
return self._allocators[device].get_physical_block_id(absolute_id)
def swap(self, blocks: List[Block], source_device: Device,
dest_device: Device) -> Dict[int, int]:
def swap(self, blocks: List[Block], src_device: Device,
dst_device: Device) -> Dict[int, int]:
"""Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each
......@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
blocks: List of blocks to be swapped.
source_device (Device): Device to swap the 'blocks' from.
dest_device (Device): Device to swap the 'blocks' to.
src_device (Device): Device to swap the 'blocks' from.
dst_device (Device): Device to swap the 'blocks' to.
Returns:
Dict[int, int]: Swap mapping from source_device
on to dest_device.
"""
source_block_ids = [block.block_id for block in blocks]
self._allocators[source_device].swap_out(blocks)
self._allocators[dest_device].swap_in(blocks)
dest_block_ids = [block.block_id for block in blocks]
src_block_ids = [block.block_id for block in blocks]
self._allocators[src_device].swap_out(blocks)
self._allocators[dst_device].swap_in(blocks)
dst_block_ids = [block.block_id for block in blocks]
current_swap_mapping: Dict[int, int] = {}
for src, dest in zip(source_block_ids, dest_block_ids):
if src is not None and dest is not None:
self._swap_mapping[src] = dest
current_swap_mapping[src] = dest
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
if src_block_id is not None and dst_block_id is not None:
self._swap_mapping[src_block_id] = dst_block_id
current_swap_mapping[src_block_id] = dst_block_id
return current_swap_mapping
def get_num_blocks_touched(self,
......@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_computed_block_ids(
prev_computed_block_ids, block_ids, skip_last_block_id)
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
computed_seq_block_ids)
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
......@@ -341,6 +364,11 @@ class NullBlock(Block):
def token_ids(self) -> List[BlockId]:
return self._proxy.token_ids
@property
def num_tokens_total(self) -> int:
raise NotImplementedError(
"num_tokens_total is not used for null block")
@property
def num_empty_slots(self) -> BlockId:
return self._proxy.num_empty_slots
......
......@@ -28,6 +28,13 @@ class Block(ABC):
def token_ids(self) -> List[int]:
pass
@property
@abstractmethod
def num_tokens_total(self) -> int:
"""The number of tokens till the current block (inclusive)
"""
pass
@property
@abstractmethod
def num_empty_slots(self) -> int:
......@@ -92,12 +99,18 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
pass
@abstractmethod
def allocate_immutable_blocks(
self, prev_block: Optional[Block],
block_token_ids: List[List[int]]) -> List[Block]:
pass
@abstractmethod
......@@ -146,13 +159,19 @@ class BlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
pass
......@@ -174,13 +193,20 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device) -> List[Block]:
pass
@abstractmethod
......@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
......@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
pass
@abstractmethod
def swap(self, blocks: List[Block], source_device: Device,
dest_device: Device) -> Dict[int, int]:
def swap(self, blocks: List[Block], src_device: Device,
dst_device: Device) -> Dict[int, int]:
pass
@abstractmethod
......
from typing import FrozenSet, Iterable, List, Optional, Set, Tuple
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.utils import cdiv
......@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator):
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
block_pool: Optional[BlockPool] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices: Set[BlockId] = set(block_ids)
self._free_block_indices: Deque[BlockId] = deque(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter(
all_block_indices=self._free_block_indices)
self._create_block = create_block
self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
refcounter=self._refcounter.as_readonly())
if block_pool is None:
extra_factor = 4
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
self._block_pool = BlockPool(self._block_size, create_block, self,
num_blocks * extra_factor)
else:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self._block_pool = block_pool
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
......@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable(prev_block=prev_block)
block = self.allocate_mutable_block(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
assert device is None
num_blocks = len(block_token_ids)
block_ids = []
for i in range(num_blocks):
block_ids.append(self._allocate_block_id())
blocks = []
for i in range(num_blocks):
prev_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block_token_ids[i],
block_size=self._block_size,
physical_block_id=block_ids[i])
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
......@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated mutable block.
"""
assert device is None
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
token_ids=[],
block_id=block_id,
block_size=self._block_size,
allocator=self,
)
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
block_id = self._allocate_block_id()
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
return block
def _allocate_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = self._free_block_indices.popleft()
self._refcounter.incr(block_id)
return block_id
def _free_block_id(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
block.block_id = None
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
self._free_block_id(block)
# Release the block object
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
......@@ -111,7 +165,7 @@ class NaiveBlockAllocator(BlockAllocator):
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
forked_blocks: List[Block] = []
prev_block = None
for block in source_blocks:
......@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_size=self._block_size,
physical_block_id=block.block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
......@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
def get_num_total_blocks(self) -> int:
return len(self._all_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = next(iter(self._free_block_indices))
self._refcounter.incr(block_id)
self._free_block_indices.remove(block_id)
return block_id
def _free_block_id(self, block_id: BlockId) -> None:
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.add(block_id)
def get_physical_block_id(self, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
......@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
......@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
......@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
"""
pass
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool) -> List[int]:
"""No prefix caching here => return empty list
"""
return []
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
......@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
raise NotImplementedError("There is no promotion for naive blocks")
def get_num_blocks_touched(self,
blocks: List[Block],
......@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
def swap_out(self, blocks: List[Block]) -> None:
for block in blocks:
self.free(block)
self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None:
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
alloc = self.allocate_immutable(block.prev_block,
block.token_ids)
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
else:
alloc = self.allocate_mutable(block.prev_block)
alloc.append_token_ids(block.token_ids)
block.block_id = alloc.block_id
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
tmp_block.block_id = None
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class NaiveBlock(Block):
......@@ -315,11 +382,12 @@ class NaiveBlock(Block):
self._append_token_ids_no_cow(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block, instructing the allocator
to perform a copy-on-write if necessary.
"""Appends the given token IDs to the block and performs a
copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
token_ids (Optional[List[int]]): The token IDs to be appended
to the block.
"""
self._append_token_ids_no_cow(token_ids)
......@@ -328,7 +396,16 @@ class NaiveBlock(Block):
self._cow_target))
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids)
"""Appends the given token IDs to the block
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
if len(token_ids) == 0:
return
assert len(token_ids) <= self.num_empty_slots
self._token_ids.extend(token_ids)
@property
......@@ -361,12 +438,17 @@ class NaiveBlock(Block):
@property
def num_empty_slots(self) -> int:
return self._block_size - len(self._token_ids)
return self._block_size - len(self.token_ids)
@property
def token_ids(self) -> List[int]:
return self._token_ids
@property
def num_tokens_total(self) -> int:
raise NotImplementedError(
"num_tokens_total is not used for naive block")
@property
def block_size(self) -> int:
return self._block_size
......
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.utils import cdiv
......@@ -19,6 +19,30 @@ PrefixHash = int
_DEFAULT_LAST_ACCESSED_TIME = -1
class BlockTracker:
"""Used to track the status of a block inside the prefix caching allocator
"""
__slots__ = ("active", "last_accessed", "computed")
def reset(self):
self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self.computed: bool = False
def __init__(self):
self.active: bool = False
self.reset()
def enable(self):
assert not self.active
self.active = True
self.reset()
def disable(self):
assert self.active
self.active = False
self.reset()
class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching.
......@@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
if block_ids is None:
block_ids = range(num_blocks)
self._block_size = block_size
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of blockId to Block to track those cached blocks
self._blocks: Dict[BlockId, Block] = {}
# Used to track status of each physical block id
self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids:
self._block_tracker[block_id] = BlockTracker()
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
extra_factor = 4
self._block_pool = BlockPool(self._block_size, self._create_block,
self, num_blocks * extra_factor)
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
......@@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
block_pool=self._block_pool, # Share block pool here
)
self._block_size = block_size
# Evitor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy)
......@@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
refcounter=self._refcounter.as_readonly())
# Implements Block.Factory.
def _create_block(
......@@ -90,14 +125,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
prefix_caching_allocator=allocator,
allocator=allocator,
computed=computed,
)
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
......@@ -111,29 +146,41 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
allocator=self,
)
# First, try to create a block that points to cached data
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
physical_block_id=None)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
block.block_id = cached_block_id
self._incr_refcount_cached_block(block, block.block_id)
self._incr_refcount_cached_block(block)
return block
self._block_pool.free_block(block)
block = self.allocate_mutable(prev_block)
# No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block)
block.append_token_ids(token_ids)
assert block.content_hash is not None
return block
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
blocks = []
for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids,
device=device)
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
......@@ -147,113 +194,154 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None
assert_prefix_caching_block_or_none(prev_block)
try:
block = self._hashless_allocator.allocate_mutable(
prev_block=prev_block)
block_id = self._allocate_block_id()
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
assert not block.computed
assert block.content_hash is None
return block
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
def _incr_refcount_cached_block(self, block: Block) -> None:
# Set this block to be "computed" since it is pointing to a
# cached block id (which was already computed)
block.computed = True
# If the evictor has blocks available for eviction, evict a block
# and return it.
if self.evictor.num_blocks > 0:
# here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
block_id = block.block_id
assert block_id is not None
_block_id = self._cached_blocks[content_hash_to_evict]
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
refcount = self._refcounter.incr(block_id)
if refcount == 1:
# In case a cached block was evicted, restore its tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._cached_blocks.pop(content_hash_to_evict)
self._track_block_id(block_id, computed=True)
self._refcounter.incr(block_id)
def _decr_refcount_cached_block(self, block: Block) -> None:
# Ensure this is immutable/cached block
assert block.content_hash is not None
# the block comes from evictor already contain computed result
block = self._create_block(
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
computed=True,
)
assert block.content_hash is None
block_id = block.block_id
assert block_id is not None
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
refcount = self._refcounter.decr(block_id)
if refcount > 0:
block.block_id = None
return
else:
assert refcount == 0
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
# No longer used
assert block.content_hash in self._cached_blocks
def _incr_refcount_cached_block(self, block: Block,
block_id: BlockId) -> None:
# now _incr_refcount_cached_block comes from two place
# allocate_immutable/promote_to_immutable_block where hit
# _cached_blocks hash key.
# In both cases, it means that already exists a already
# computed block which shared with block now
block.computed = True
# Add the cached block to the evictor
# (This keeps the cached block around so it can be reused)
self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
self._block_tracker[block_id].last_accessed)
refcount = self._refcounter.incr(block_id)
# Stop tracking the block
self._untrack_block_id(block_id)
block.block_id = None
def _decr_refcount_hashless_block(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
# We may have a fork case where block is shared,
# in which case, we cannot remove it from tracking
refcount = self._refcounter.get(block_id)
if refcount == 1:
# if block get referred, then it shall not be in evictor
# and put it into _blocks for tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._blocks[block_id] = block
self._untrack_block_id(block_id)
def free(self, block: Block) -> None:
"""Decrement the refcount of the block. If the decremented refcount is
zero, store the block in the freelist.
# Decrement refcount of the block_id, but do not free the block object
# itself (will be handled by the caller)
self._hashless_allocator.free(block, keep_block_object=True)
If the block has a content hash (meaning it is immutable), then we will
keep the block around in case future allocations require it.
def _allocate_block_id(self) -> BlockId:
"""First tries to allocate a block id from the hashless allocator,
and if there are no blocks, then tries to evict an unused cached block.
"""
assert (block.block_id
is not None), "freeing unallocated block is undefined"
hashless_block_id = self._maybe_allocate_hashless_block_id()
if hashless_block_id is not None:
return hashless_block_id
self._free_block_id_for_block(block.block_id, block)
evicted_block_id = self._maybe_allocate_evicted_block_id()
if evicted_block_id is not None:
return evicted_block_id
block.block_id = None
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
try:
# Allocate mutable block and extract its block_id
block = self._hashless_allocator.allocate_mutable_block(
prev_block=None)
block_id = block.block_id
self._block_pool.free_block(block)
self._track_block_id(block_id, computed=False)
return block_id
except BlockAllocator.NoFreeBlocksError:
return None
def _free_block_id_for_block(self, block_id: BlockId,
block: Block) -> None:
assert isinstance(block, PrefixCachingBlock)
# if we comes from promote_to_immutable_block, it means that
# block.content_hash is never None.
# However we need to release the same content block, so that
# physical block could get reused.
if block.block_id != block_id or block.content_hash is None:
refcount = self._refcounter.get(block_id)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
assert block.block_id is not None
refcount = self._refcounter.get(block.block_id)
if refcount == 1:
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
if self.evictor.num_blocks == 0:
return None
refcount = self._refcounter.decr(block_id)
# Here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
# Sanity checks
assert content_hash_to_evict in self._cached_blocks
_block_id = self._cached_blocks[content_hash_to_evict]
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
# If no longer used, add the block to the evictor.
if refcount == 0:
assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed)
self._cached_blocks.pop(content_hash_to_evict)
self._refcounter.incr(block_id)
self._track_block_id(block_id, computed=False)
return block_id
def _free_block_id(self, block: Block) -> None:
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
block_id = block.block_id
assert block_id is not None, "Freeing unallocated block is undefined"
if block.content_hash is not None:
# Immutable: This type of block is always cached, and we want to
# keep it in the evictor for future reuse
self._decr_refcount_cached_block(block)
else:
# Mutable: This type of block is not cached, so we release it
# directly to the hashless allocator
self._decr_refcount_hashless_block(block)
assert block.block_id is None
def free(self, block: Block, keep_block_object: bool = False) -> None:
"""Release the block (look at free_block_id(..) docs)
"""
# Release the physical block index
self._free_block_id(block)
# Release the block object to the pool
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
......@@ -268,20 +356,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
forked_blocks: List[Block] = []
prev_block = None
for block in source_blocks:
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.incr(block_id)
assert refcount != 1, "can't fork free'd block_id = {}".format(
block_id)
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_size=self._block_size,
physical_block_id=block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
......@@ -326,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
block.
block id.
Args:
block: The mutable block to be promoted.
......@@ -335,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator):
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
"""
# Ensure block can be promoted
assert block.content_hash is not None
assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached
# (Note that this block is not computed yet =>
# Will be computed after free())
self._cached_blocks[block.content_hash] = block.block_id
else:
self._free_block_id_for_block(
self._cached_blocks[block.content_hash], block)
self._incr_refcount_cached_block(
block, self._cached_blocks[block.content_hash])
return block.block_id
# Reuse the cached content hash
self._decr_refcount_hashless_block(block)
block.block_id = self._cached_blocks[block.content_hash]
return self._cached_blocks[block.content_hash]
# Increment refcount of the cached block and (possibly) restore
# it from the evictor.
# Note that in this case, the block is marked as computed
self._incr_refcount_cached_block(block)
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
return block.block_id
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
......@@ -359,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
......@@ -383,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
for block_id in block_ids:
if block_id in self._blocks:
self._blocks[block_id].last_accessed = now
if self._block_tracker[block_id].active:
self._block_tracker[block_id].last_accessed = now
elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
......@@ -392,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
raise NotImplementedError("Marking as computed is incremental")
for block_id in block_ids:
if block_id in self._blocks:
# only those full block is valid for prefix caching
if self._blocks[block_id].is_full:
self._blocks[block_id].computed = True
elif block_id not in self.evictor:
raise ValueError(f"Mark {block_id=} as computed which "
"is not belonged to GPU")
def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:
assert block_id is not None
self._block_tracker[block_id].enable()
self._block_tracker[block_id].computed = computed
def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
assert block_id is not None
self._block_tracker[block_id].disable()
def block_is_computed(self, block_id: int) -> bool:
if block_id in self._blocks:
return self._blocks[block_id].computed
if self._block_tracker[block_id].active:
return self._block_tracker[block_id].computed
else:
return block_id in self.evictor
def get_computed_block_ids(self,
prev_computed_block_ids: List[int],
block_ids: List[int],
skip_last_block_id: bool = True) -> List[int]:
prev_prefix_size = len(prev_computed_block_ids)
cur_size = len(block_ids)
if skip_last_block_id:
cur_size -= 1
# Sanity checks
assert cur_size >= 0
assert prev_prefix_size <= cur_size
ret = prev_computed_block_ids
for i in range(prev_prefix_size, cur_size):
block_id = block_ids[i]
if self.block_is_computed(block_id):
ret.append(block_id)
return ret
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Only those blocks that are immutable and already be marked
......@@ -421,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prompt is cached. This would cause erroneous behavior in model
# runner.
ids_list = [
list(
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
# It returns a list of int although type annotation says list of string.
return commonprefix([
ids for ids in ids_list # type: ignore
ids for ids in computed_seq_block_ids # type: ignore
if ids != []
])
......@@ -470,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped out.
"""
for block in blocks:
self.free(block)
self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None:
"""Execute the swap int actions. Change the block id from
"""Execute the swap in actions. Change the block id from
old allocator to current allocator for each block to finish
the block table update.
......@@ -481,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped in.
"""
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
alloc = self.allocate_immutable(block.prev_block,
block.token_ids)
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
else:
alloc = self.allocate_mutable(block.prev_block)
alloc.append_token_ids(block.token_ids)
block.block_id = alloc.block_id
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class PrefixCachingBlock(Block):
......@@ -504,7 +638,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (BlockAllocator): The prefix
allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
......@@ -515,31 +649,55 @@ class PrefixCachingBlock(Block):
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: BlockAllocator,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator. Got instead allocator = {}".format(
allocator))
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self._cached_num_tokens_total: int = 0
self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._block = NaiveBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=prefix_caching_allocator,
_cow_target=self,
)
# On the first time, we create the block object, and next we only
# reinitialize it
if hasattr(self, "_block"):
self._block.__init__( # type: ignore[has-type]
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
else:
self._block = NaiveBlock(prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
self._update_num_tokens_total()
def _update_num_tokens_total(self):
"""Incrementally computes the number of tokens that there is
till the current block (included)
"""
res = 0
# Add all previous blocks
if self._prev_block is not None:
res += self._prev_block.num_tokens_total
# Add current block
res += len(self.token_ids)
self._cached_num_tokens_total = res
@property
def computed(self) -> bool:
......@@ -561,22 +719,28 @@ class PrefixCachingBlock(Block):
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
assert token_ids
# Ensure this is mutable block (not promoted)
assert self.content_hash is None
assert not self.computed
if len(token_ids) == 0:
return
# naive block handles CoW.
# Ensure there are input tokens
assert token_ids, "Got token_ids = {}".format(token_ids)
# Naive block handles CoW.
self._block.append_token_ids(token_ids)
self._update_num_tokens_total()
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if self.content_hash is not None:
self.block_id = (self._prefix_caching_allocator.
promote_to_immutable_block(self))
self.block_id = self._allocator.promote_to_immutable_block(self)
@property
def block_id(self) -> Optional[int]:
......@@ -596,23 +760,6 @@ class PrefixCachingBlock(Block):
@property
def num_tokens_total(self) -> int:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total
@property
......@@ -635,7 +782,6 @@ class PrefixCachingBlock(Block):
For the content-based hash to be defined, the current block must be
full.
"""
# If the hash is already computed, return it.
if self._cached_content_hash is not None:
return self._cached_content_hash
......@@ -685,7 +831,129 @@ class PrefixCachingBlock(Block):
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
class ComputedBlocksTracker:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
def __init__(self, allocator):
self._allocator = allocator
self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
bool]] = {}
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._cached_computed_seq_blocks
self._cached_computed_seq_blocks[seq_id] = ([], False)
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._cached_computed_seq_blocks
del self._cached_computed_seq_blocks[seq_id]
def get_cached_computed_blocks_and_update(
self, seq_id: int, block_ids: List[int]) -> List[int]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert seq_id in self._cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
seq_id]
if has_gap:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks = len(block_ids) - 1
assert num_cur_blocks >= 0
if len(prev_computed_block_ids) >= num_cur_blocks:
# Cache HIT
assert len(prev_computed_block_ids) == num_cur_blocks
return prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
prev_computed_block_ids,
block_ids,
skip_last_block_id=
True, # We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
has_gap = len(computed_block_ids) < num_cur_blocks
# Record
self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
has_gap)
return computed_block_ids
class LastAccessBlocksTracker:
"""Manages the last access time of the tracked sequences, in order to allow
an efficient update of allocator's block last access times
"""
def __init__(self, allocator):
self._allocator = allocator
self._seq_last_access: Dict[int, Optional[float]] = {}
def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._seq_last_access
self._seq_last_access[seq_id] = None
def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._seq_last_access
del self._seq_last_access[seq_id]
def update_last_access(self, seq_id: int, time: float) -> None:
assert seq_id in self._seq_last_access
self._seq_last_access[seq_id] = time
def update_seq_blocks_last_access(self, seq_id: int,
block_ids: List[int]) -> None:
assert seq_id in self._seq_last_access
ts = self._seq_last_access[seq_id]
if ts is None:
# No last access was recorded, no need to update.
return
self._allocator.mark_blocks_as_accessed(block_ids, ts)
def assert_prefix_caching_block_or_none(block: Optional[Block]):
if block is None:
return
assert isinstance(block, PrefixCachingBlock)
assert isinstance(block,
PrefixCachingBlock), "Got block = {}".format(block)
......@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \
else len(seq.logical_token_blocks)
return 0 if seq is None else seq.n_blocks
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
......@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
num_prompt_blocks = seq.n_blocks
block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks):
......@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
new_hash = seq.hash_of_block(seq.n_blocks - 1)
# if new_hash is already in the cached table, then free last_block
# and return the cached version
......@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
......@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
assert len(block_table) == n_blocks - 1
if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
......@@ -472,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
......
......@@ -7,6 +7,8 @@ from typing import Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
......@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
......@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table: BlockTable = self._allocate_sequence(seq)
self.block_tables[seq.seq_id] = block_table
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
......@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return new_cows
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
seq_id = seq.seq_id
if seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
# Update seq block ids with the latest access time
self._last_access_blocks_tracker.update_seq_blocks_last_access(
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
# Untrack seq
self._last_access_blocks_tracker.remove_seq(seq_id)
self._computed_blocks_tracker.remove_seq(seq_id)
# Free table/blocks
self.block_tables[seq_id].free()
del self.block_tables[seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id
......@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
del self.cross_block_tables[request_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
......@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
# in this step.
# And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if self.enable_caching:
block_table = self.block_tables[seq.seq_id]
block_ids = []
for block_id in block_table.physical_block_ids:
block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
# Record the latest access time for the sequence. The actual update
# of the block ids is deferred to the sequence free(..) call, since
# only during freeing of block ids, the blocks are actually added to
# the evictor (which is when the most updated time is required)
# (This avoids expensive calls to mark_blocks_as_accessed(..))
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching,
......@@ -285,17 +304,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
computed_seq_block_ids = []
for seq in seqs:
computed_seq_block_ids.append(
self._computed_blocks_tracker.
get_cached_computed_blocks_and_update(
seq.seq_id,
self.block_tables[seq.seq_id].physical_block_ids))
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) # type: ignore
computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
"""Returns the AllocStatus for the given sequence_group
......@@ -323,19 +354,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
"""
blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED)
current_swap_mapping = self.block_allocator.swap(
blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU)
block_number_mapping = {
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id):
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id)
for cpu_block_id, gpu_block_id in current_swap_mapping.items()
}
# convert to list of tuples once here
return list(block_number_mapping.items())
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.CPU,
dst_device=Device.GPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id):
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id)
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
"""Returns whether we can swap out the given sequence_group
......@@ -355,7 +398,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return True
return False
def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]:
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots.
......@@ -366,19 +409,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU.
"""
blocks = self._get_blocks_for_swap(sequence_group,
SequenceStatus.RUNNING)
current_swap_mapping = self.block_allocator.swap(
blocks=blocks, source_device=Device.GPU, dest_device=Device.CPU)
block_number_mapping = {
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id):
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id)
for gpu_block_id, cpu_block_id in current_swap_mapping.items()
}
# convert to list of tuples once here
return list(block_number_mapping.items())
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.GPU,
dst_device=Device.CPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id):
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id)
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
......
......@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
......@@ -50,8 +51,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
......@@ -65,28 +66,28 @@ class SchedulingBudget:
return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens:
if req_id in self._request_ids_num_batched_tokens:
return
self._requeset_ids_num_batched_tokens.add(req_id)
self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens
def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens:
self._requeset_ids_num_batched_tokens.remove(req_id)
if req_id in self._request_ids_num_batched_tokens:
self._request_ids_num_batched_tokens.remove(req_id)
self._num_batched_tokens -= num_batched_tokens
def add_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs:
if req_id in self._request_ids_num_curr_seqs:
return
self._requeset_ids_num_curr_seqs.add(req_id)
self._request_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs
def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs:
self._requeset_ids_num_curr_seqs.remove(req_id)
if req_id in self._request_ids_num_curr_seqs:
self._request_ids_num_curr_seqs.remove(req_id)
self._num_curr_seqs -= num_curr_seqs
@property
......@@ -139,6 +140,8 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
......@@ -157,6 +160,14 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass
class SchedulerRunningOutputs:
......@@ -256,6 +267,7 @@ class Scheduler:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
......@@ -273,11 +285,19 @@ class Scheduler:
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)
num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size
num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
......@@ -290,7 +310,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step
self.prev_time = 0.0
# Did we schedule a prompt at previous step?
......@@ -364,6 +387,12 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self._finished_requests_ids
self._finished_requests_ids = list()
return finished_requests_ids
def _schedule_running(
self,
running_queue: deque,
......@@ -1006,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
seq_group_metadata_list.append(seq_group_metadata)
......@@ -1027,6 +1057,11 @@ class Scheduler:
self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment