Commit 705f6a35 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
...@@ -7,9 +7,17 @@ import torch ...@@ -7,9 +7,17 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import PagedAttentionMetadata
PagedAttentionMetadata) from vllm.utils import is_cpu
if is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
...@@ -23,8 +31,8 @@ class TorchSDPABackend(AttentionBackend): ...@@ -23,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return TorchSDPABackendImpl return TorchSDPABackendImpl
@staticmethod @staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata(*args, **kwargs) return TorchSDPAMetadata
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -137,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -137,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -150,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -150,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
...@@ -197,13 +211,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -197,13 +211,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata.attn_bias): attn_metadata.attn_bias):
end = start + seq_len end = start + seq_len
sub_out = scaled_dot_product_attention( sub_out = scaled_dot_product_attention(
query[:, start:end, :], query[None, :, start:end, :],
key[:, start:end, :], key[None, :, start:end, :],
value[:, start:end, :], value[None, :, start:end, :],
attn_mask=mask, attn_mask=mask,
dropout_p=0.0, dropout_p=0.0,
is_causal=not self.need_mask, is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0) scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out output[start:end, :, :] = sub_out
start = end start = end
else: else:
...@@ -236,7 +251,7 @@ def _make_alibi_bias( ...@@ -236,7 +251,7 @@ def _make_alibi_bias(
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: List[int], seq_lens: List[int],
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens: for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
...@@ -248,7 +263,7 @@ def _make_alibi_bias( ...@@ -248,7 +263,7 @@ def _make_alibi_bias(
num_heads = alibi_slopes.shape[0] num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1)) bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty( inf_mask = torch.empty(
(1, seq_len, seq_len), (1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
...@@ -262,7 +277,7 @@ def _make_sliding_window_bias( ...@@ -262,7 +277,7 @@ def _make_sliding_window_bias(
window_size: Optional[int], window_size: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens: for seq_len in seq_lens:
tensor = torch.full( tensor = torch.full(
(1, seq_len, seq_len), (1, seq_len, seq_len),
......
"""Attention backend utils"""
# Message text used when an encoder/decoder attention scenario is
# requested on ROCm/HIP, which is not supported.
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = (
    "ROCm/HIP is not currently supported with encoder/decoder models.")
...@@ -6,10 +6,11 @@ import torch ...@@ -6,10 +6,11 @@ import torch
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias, from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask, BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -28,8 +29,8 @@ class XFormersBackend(AttentionBackend): ...@@ -28,8 +29,8 @@ class XFormersBackend(AttentionBackend):
return XFormersImpl return XFormersImpl
@staticmethod @staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata": def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata(*args, **kwargs) return XFormersMetadata
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API. updated from `CUDAGraphRunner.forward` API.
""" """
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
...@@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------| # |-------------------- seq_len ----------------------|
# |-- query_len ---| # |-- query_len ---|
# Maximum query length in the batch. None for decoding. # seq_lens stored as a tensor.
max_query_len: Optional[int] seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn. # FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
...@@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum sequence length among decode batch. 0 if there are prefill # Maximum sequence length among decode batch. 0 if there are prefill
# requests only. # requests only.
max_decode_seq_len: int max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # Whether or not if cuda graph is enabled.
# is [4, 6], it is [0, 4, 10]. # Cuda-graph is currently enabled for decoding only.
query_start_loc: Optional[torch.Tensor] # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn. # FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in # (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is # the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10]. # [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed # (batch_size,) A tensor of context lengths (tokens that are computed
# so far). # so far).
context_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor] = None
# Whether or not if cuda graph is enabled. # Maximum query length in the batch. None for decoding.
# Cuda-graph is currently enabled for decoding only. max_query_len: Optional[int] = None
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersMetadata"] = None _cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None _cached_decode_metadata: Optional["XFormersMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt # It is a list because it is needed to set per prompt
...@@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API. # from xformer API.
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
    '''
    Report whether every field needed for encoder attention is populated.

    True only when encoder_seq_lens, encoder_seq_lens_tensor and
    max_encoder_seq_len are all non-None; False otherwise.
    '''
    # Checked in the same order as the fields are listed above; the
    # first missing field decides the answer.
    if self.encoder_seq_lens is None:
        return False
    if self.encoder_seq_lens_tensor is None:
        return False
    return self.max_encoder_seq_len is not None
@property
def is_all_cross_attn_metadata_set(self):
    '''
    Report whether every field needed for encoder/decoder
    cross-attention is populated.

    Requires everything encoder attention requires
    (is_all_encoder_attn_metadata_set) plus non-None
    cross_slot_mapping and cross_block_tables.
    '''
    encoder_ready = self.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Encoder prerequisites missing; cross-attention cannot be ready.
        return encoder_ready
    if self.cross_slot_mapping is None:
        return False
    return self.cross_block_tables is not None
@property @property
def prefill_metadata(self) -> Optional["XFormersMetadata"]: def prefill_metadata(self) -> Optional["XFormersMetadata"]:
...@@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None return None
if self._cached_prefill_metadata is not None: if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata return self._cached_prefill_metadata
assert self.seq_lens is not None assert ((self.seq_lens is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens is not None))
assert self.query_start_loc is not None assert ((self.seq_lens_tensor is not None)
assert self.context_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
assert self.block_tables is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersMetadata( self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=slot_mapping,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=seq_lens,
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0, max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1], query_start_loc=query_start_loc,
seq_start_loc=None, context_lens_tensor=context_lens_tensor,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=block_tables,
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
...@@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return None return None
if self._cached_decode_metadata is not None: if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata return self._cached_decode_metadata
assert self.block_tables is not None assert ((self.seq_lens_tensor is not None)
assert self.seq_lens_tensor is not None or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata( self._cached_decode_metadata = XFormersMetadata(
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=slot_mapping,
seq_lens=None, seq_lens_tensor=seq_lens_tensor,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None, block_tables=block_tables,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
) # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata return self._cached_decode_metadata
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[AttentionBias]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * The attention bias value stored for the given attention type
      (None if it has not been set yet)

    Raises:

    * AttributeError: attn_type is not one of DECODER, ENCODER or
                      ENCODER_DECODER
    '''
    # NOTE(review): the metadata fields read here are declared as
    # Optional[List[AttentionBias]] in XFormersMetadata.__post_init__,
    # so the return annotation looks narrower than what is actually
    # returned - confirm before tightening it.
    if attn_type == AttentionType.DECODER:
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    else:
        # Reject unknown attention types explicitly instead of silently
        # treating them as cross-attention; this mirrors _set_attn_bias
        # and _get_seq_len_block_table_args.
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias on the metadata field that corresponds to
    the given attention type.

    Arguments:

    * attn_metadata: Attention metadata structure to update in place
    * attn_bias: The attention bias value to store
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:

    * AttributeError: attn_type is not a recognized attention type
    '''
    # Decoder self-attention bias
    if attn_type == AttentionType.DECODER:
        attn_metadata.attn_bias = attn_bias
        return

    # Encoder self-attention bias
    if attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
        return

    # Encoder/decoder cross-attention bias
    if attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
        return

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table attributes of attn_metadata
    that apply to the requested type of attention operation.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> encoder sequence lengths plus the
                                  cross-attn block tables
    Encoder attn -> encoder sequence lengths and no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns (as a 3-tuple):

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError: attn_type is not a recognized attention type
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max sequence length depends on
        # whether this is a prefill (prompt) run or a decode run.
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs follow the encoder sequence lengths and
        # are addressed through the dedicated "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersImpl(AttentionImpl[XFormersMetadata]): class XFormersImpl(AttentionImpl[XFormersMetadata]):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
...@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: Optional[torch.Tensor],
value: torch.Tensor, value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0, kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args: Args:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: # Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache. if (key is not None) and (value is not None):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. if attn_type == AttentionType.ENCODER_DECODER:
PagedAttention.write_to_paged_cache(key, value, key_cache, # Update cross-attention KV cache (prefill-only)
value_cache, # During cross-attention decode, key & value will be None,
attn_metadata.slot_mapping, # preventing this IF-statement branch from running
self.kv_cache_dtype, kv_scale) updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
num_prefill_tokens = attn_metadata.num_prefill_tokens # Update self-attention KV cache (prefill/decode)
num_decode_tokens = attn_metadata.num_decode_tokens updated_slot_mapping = attn_metadata.slot_mapping
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
kv_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# take the token count from attn_metadata.num_encoder_tokens
# and treat all of them as prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query) output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens] if key is not None and value is not None:
value = value[:num_prefill_tokens] key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
...@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# block tables are empty if the prompt does not have a cached # block tables are empty if the prompt does not have a cached
# prefix. # prefix.
out = self._run_memory_efficient_xformers_forward( out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta) query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_tokens].shape assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out output[:num_prefill_tokens] = out
else: else:
assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
...@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output[:num_prefill_tokens] = out output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_tokens:] = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
decode_meta.block_tables, block_tables_arg,
decode_meta.seq_lens_tensor, seq_lens_arg,
decode_meta.max_decode_seq_len, max_seq_len_arg,
self.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: XFormersMetadata, attn_metadata: XFormersMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
...@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_prefill_tokens, num_kv_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
""" """
assert attn_metadata.seq_lens is not None
original_query = query original_query = query
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K]. # GQA/MQA requires the shape [B, M, G, H, K].
...@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
if attn_metadata.attn_bias is None: attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None: if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens( if (attn_type == AttentionType.ENCODER_DECODER):
attn_metadata.seq_lens) assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None
# Default enc/dec cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
elif attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
# Default encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
else:
assert attn_metadata.seq_lens is not None
# Default decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
if self.sliding_window is not None: if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention( attn_bias = attn_bias.make_local_attention(
self.sliding_window) self.sliding_window)
attn_metadata.attn_bias = [attn_bias] attn_bias = [attn_bias]
else: else:
attn_metadata.attn_bias = _make_alibi_bias( assert attn_metadata.seq_lens is not None
self.alibi_slopes, self.num_kv_heads, query.dtype, attn_bias = _make_alibi_bias(self.alibi_slopes,
attn_metadata.seq_lens) self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
_set_attn_bias(attn_metadata, attn_bias, attn_type)
# No alibi slopes. # No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
...@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query, query,
key, key,
value, value,
attn_bias=attn_metadata.attn_bias[0], attn_bias=attn_bias[0],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
return out.view_as(original_query) return out.view_as(original_query)
...@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# FIXME(woosuk): Because xformers does not support dynamic sequence # FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
assert attn_metadata.seq_lens is not None
output = torch.empty_like(original_query) output = torch.empty_like(original_query)
start = 0 start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens): for i, seq_len in enumerate(attn_metadata.seq_lens):
...@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query[None, start:end], query[None, start:end],
key[None, start:end], key[None, start:end],
value[None, start:end], value[None, start:end],
attn_bias=attn_metadata.attn_bias[i], attn_bias=attn_bias[i],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
...@@ -431,8 +747,8 @@ def _make_alibi_bias( ...@@ -431,8 +747,8 @@ def _make_alibi_bias(
num_kv_heads: int, num_kv_heads: int,
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: List[int], seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias: ) -> List[AttentionBias]:
attn_biases = [] attn_biases: List[AttentionBias] = []
for seq_len in seq_lens: for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
......
...@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional ...@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
class Attention(nn.Module): class Attention(nn.Module):
...@@ -56,15 +57,19 @@ class Attention(nn.Module): ...@@ -56,15 +57,19 @@ class Attention(nn.Module):
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self) if quant_config else None self) if quant_config else None
if quant_method is not None: if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2": assert isinstance(quant_method, Fp8KVCacheMethod)
raise ValueError("fp8_e5m2 kv-cache is not supported with " # TODO (mgoin): kv cache dtype should be specified in the FP8
"fp8 checkpoints.") # checkpoint config and become the "auto" behavior
# When FP8 quantization is enabled, we make a parameter if "fp8" in self.kv_cache_dtype:
# "kv_scale" so that it can be loaded from FP8 checkpoint. if self.kv_cache_dtype == "fp8_e5m2":
# The kv_scale will then be converted back raise ValueError("fp8_e5m2 kv-cache is not supported with "
# to self._kv_scale in a native float32 value after weight loading. "fp8 checkpoints.")
self.quant_method = quant_method # When FP8 quantization is enabled, we make a parameter
self.quant_method.create_weights(self) # "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back to self._kv_scale
# in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model # During model initialization, the default dtype is set as the model
# weight and activation dtype. # weight and activation dtype.
...@@ -85,9 +90,16 @@ class Attention(nn.Module): ...@@ -85,9 +90,16 @@ class Attention(nn.Module):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._kv_scale) return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._kv_scale,
attn_type=attn_type)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
......
...@@ -2,13 +2,14 @@ import math ...@@ -2,13 +2,14 @@ import math
import torch import torch
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8) and current_platform.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE: if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
...@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module): ...@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
v, v,
cu_seqlens_k, cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
sm_scale=sm_scale) sm_scale=sm_scale)
\ No newline at end of file
...@@ -4,9 +4,35 @@ ...@@ -4,9 +4,35 @@
from functools import lru_cache from functools import lru_cache
import numpy as np
import torch import torch
import triton import triton
from scipy import sparse
class csr_matrix:
    """Minimal dense-to-CSR conversion that does not depend on scipy.

    Stands in for the ``scipy.sparse.csr_matrix()`` call used previously and
    exposes only the attributes read from it:

    * ``shape``   - shape of the dense input (must be 2-D).
    * ``data``    - the non-zero values, in row-major order.
    * ``indices`` - column index of every entry in ``data``.
    * ``indptr``  - row pointers: the entries of row ``i`` are
      ``indices[indptr[i]:indptr[i + 1]]``.

    Raises:
        ValueError: if ``input_array`` is not a NumPy array, or is not 2-D.
    """

    def __init__(self, input_array):
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")

        self.shape = input_array.shape
        # Unpacking enforces a 2-D input (ValueError otherwise), exactly as
        # the loop-based implementation did.
        rows, _ = self.shape

        # np.nonzero reports coordinates in row-major (C) order, i.e. the
        # same order a nested "for row / for col" scan would visit them.
        row_idx, col_idx = np.nonzero(input_array)

        self.data = input_array[row_idx, col_idx]
        # Force integer dtypes: building these from (possibly empty) Python
        # lists would produce float64 arrays for an all-zero input.
        self.indices = col_idx.astype(np.int64)

        # Row pointers are the running total of non-zeros per row.
        per_row = np.bincount(row_idx, minlength=rows)
        indptr = np.zeros(rows + 1, dtype=np.int64)
        indptr[1:] = np.cumsum(per_row)
        self.indptr = indptr
def dense_to_crow_col(x: torch.Tensor): def dense_to_crow_col(x: torch.Tensor):
...@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor): ...@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert x.dim() in (2, 3) assert x.dim() in (2, 3)
if x.dim() == 2: if x.dim() == 2:
x = x[None] x = x[None]
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x] x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x] cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols) max_cols = max(len(xi) for xi in cols)
...@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head( ...@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head(
): ):
""" """
:return: a tuple of 3: :return: a tuple of 3:
- tuple of crow_indices, col_indices representation - tuple of crow_indices, col_indices representation
of CSR format. of CSR format.
- block dense mask - block dense mask
- all token dense mask (be aware that it can be - all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`, OOM if it is too big) if `return_dense==True`,
otherwise, None otherwise, None
""" """
with torch.no_grad(): with torch.no_grad():
...@@ -148,10 +174,10 @@ def get_sparse_attn_mask( ...@@ -148,10 +174,10 @@ def get_sparse_attn_mask(
:param dense_mask_type: "binary" (0 for skip token, 1 for others) :param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others) or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3: :return: a tuple of 3:
- tuple of crow_indices, col_indices representation - tuple of crow_indices, col_indices representation
of CSR format. of CSR format.
- block dense mask - block dense mask
- all token dense mask (be aware that it can be OOM if it - all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None is too big) if `return_dense==True`, otherwise, None
""" """
assert dense_mask_type in ("binary", "bias") assert dense_mask_type in ("binary", "bias")
......
from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules
import torch
from vllm import _custom_ops as ops
class PagedAttention:
    """Paged-attention helpers that delegate to ``ipex_modules.PagedAttention``.

    Every method is static and accepts trailing ``*args`` that are ignored.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes this implementation accepts."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined KV cache tensor.

        The leading dimension of size 2 holds the key cache at index 0 and
        the value cache at index 1 (see ``split_kv_cache``); each block is
        stored flattened.
        """
        elems_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elems_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into key and value views.

        Both results are viewed as
        ``[num_blocks, num_kv_heads, -1, head_size]``.
        """
        num_blocks = kv_cache.shape[1]
        per_block_shape = (num_blocks, num_kv_heads, -1, head_size)
        key_cache = kv_cache[0].view(*per_block_shape)
        value_cache = kv_cache[1].view(*per_block_shape)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
        *args,
    ) -> None:
        """Store ``key``/``value`` into the paged caches via IPEX.

        ``kv_cache_dtype`` and ``kv_scale`` are accepted but not used here.
        """
        # The slot mapping is handed over flattened and as int32.
        flat_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, flat_slots)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
        *args,
    ) -> torch.Tensor:
        """Run single-query attention against the paged KV cache via IPEX.

        Returns a new tensor shaped like ``query``. ``kv_cache_dtype`` and
        ``kv_scale`` are accepted but not used here.
        """
        output = torch.empty_like(query)
        block_size = value_cache.shape[2]

        # Map each query head to its KV head: every KV head id is repeated
        # once per query head that shares it.
        queries_per_kv = query.size(1) // num_kv_heads
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        )
        head_mapping = kv_head_ids.view(
            num_kv_heads, 1).repeat_interleave(queries_per_kv).flatten()

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output,
            query.contiguous(),
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
        *args,
    ) -> torch.Tensor:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
        *args,
    ) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        """Copy cache blocks for every given KV cache via ``ops.copy_blocks``."""
        key_caches = []
        value_caches = []
        for cache in kv_caches:
            # Index 0 is the key cache, index 1 the value cache.
            key_caches.append(cache[0])
            value_caches.append(cache[1])
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
...@@ -5,6 +5,8 @@ import torch ...@@ -5,6 +5,8 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.platforms import current_platform
if triton.__version__ >= "2.1.0": if triton.__version__ >= "2.1.0":
@triton.jit @triton.jit
...@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0": ...@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None, alibi_slopes=None,
sliding_window=None): sliding_window=None):
cap = torch.cuda.get_device_capability() cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64 BLOCK = 128 if cap[0] >= 8 else 64
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if q.dtype is torch.float32:
BLOCK = BLOCK // 2
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
...@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0": ...@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0":
b_ctx_len, b_ctx_len,
alibi_slopes, alibi_slopes,
v_cache.shape[3], v_cache.shape[3],
8, k_cache.shape[4],
o, o,
b_loc.stride(0), b_loc.stride(0),
b_loc.stride(1), b_loc.stride(1),
...@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0": ...@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0":
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
v_cache.shape[3], v_cache.shape[3],
8, k_cache.shape[4],
o, o,
b_loc.stride(0), b_loc.stride(0),
b_loc.stride(1), b_loc.stride(1),
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -17,7 +17,10 @@ class _Backend(enum.Enum): ...@@ -17,7 +17,10 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto() XFORMERS = enum.auto()
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
...@@ -57,15 +60,29 @@ def get_attn_backend( ...@@ -57,15 +60,29 @@ def get_attn_backend(
ROCmFlashAttentionBackend) ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA: elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.IPEX:
assert is_xpu(), RuntimeError(
"IPEX attention backend is only used for the XPU device.")
logger.info("Using IPEX attention backend.")
from vllm.attention.backends.ipex_attn import IpexAttnBackend
return IpexAttnBackend
elif backend == _Backend.FLASHINFER: elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.") logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else: else:
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
...@@ -80,7 +97,6 @@ def which_attn_to_use( ...@@ -80,7 +97,6 @@ def which_attn_to_use(
block_size: int, block_size: int,
) -> _Backend: ) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
...@@ -100,6 +116,21 @@ def which_attn_to_use( ...@@ -100,6 +116,21 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
if is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip(): if is_hip():
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
......
...@@ -3,52 +3,9 @@ from typing import List ...@@ -3,52 +3,9 @@ from typing import List
from vllm.utils import Device from vllm.utils import Device
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1 DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock:
    """Fixed-capacity buffer of token ids that is filled from left to right.

    A logical block represents the state of its corresponding physical block
    in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size
        # Every slot starts out holding the blank-token sentinel.
        self.token_ids = [_BLANK_TOKEN_ID for _ in range(block_size)]
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """Return True when no token has been stored yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into this block."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """Return True when every slot holds a token."""
        return self.block_size == self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the next free slots.

        The caller must not pass more tokens than there are empty slots.
        """
        incoming = len(token_ids)
        assert incoming <= self.get_num_empty_slots()
        begin = self.num_tokens
        self.token_ids[begin:begin + incoming] = token_ids
        self.num_tokens = begin + incoming

    def get_token_ids(self) -> List[int]:
        """Return a copy of the token ids stored so far (blank slots omitted)."""
        filled = self.num_tokens
        return self.token_ids[:filled]

    def get_last_token_id(self) -> int:
        """Return the most recently stored token id; the block must be non-empty."""
        assert self.num_tokens > 0
        last_index = self.num_tokens - 1
        return self.token_ids[last_index]
class PhysicalTokenBlock: class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache.""" """Represents the state of a block in the KV cache."""
......
import enum import enum
import json import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
Union)
import torch import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -23,6 +25,17 @@ logger = init_logger(__name__) ...@@ -23,6 +25,17 @@ logger = init_logger(__name__)
_GB = 1 << 30 _GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
]
class ModelConfig: class ModelConfig:
"""Configuration for the model. """Configuration for the model.
...@@ -105,6 +118,7 @@ class ModelConfig: ...@@ -105,6 +118,7 @@ class ModelConfig:
disable_sliding_window: bool = False, disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["MultiModalConfig"] = None,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -123,12 +137,10 @@ class ModelConfig: ...@@ -123,12 +137,10 @@ class ModelConfig:
self.quantization = quantization self.quantization = quantization
self.quantization_param_path = quantization_param_path self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture if max_context_len_to_capture is not None:
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. " raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.") "Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture self.max_seq_len_to_capture = max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
...@@ -137,6 +149,17 @@ class ModelConfig: ...@@ -137,6 +149,17 @@ class ModelConfig:
code_revision, rope_scaling, rope_theta) code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True
self.max_model_len = _get_and_verify_max_len( self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config, hf_config=self.hf_text_config,
max_model_len=max_model_len, max_model_len=max_model_len,
...@@ -144,6 +167,8 @@ class ModelConfig: ...@@ -144,6 +167,8 @@ class ModelConfig:
sliding_window_len=self.get_hf_config_sliding_window()) sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model, self.served_model_name = get_served_model_name(model,
served_model_name) served_model_name)
self.multimodal_config = multimodal_config
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_embedding_mode() self._verify_embedding_mode()
...@@ -212,7 +237,7 @@ class ModelConfig: ...@@ -212,7 +237,7 @@ class ModelConfig:
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if (self.quantization if (self.quantization
not in ["marlin", "gptq_marlin_24", "gptq_marlin"]): not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
...@@ -228,7 +253,8 @@ class ModelConfig: ...@@ -228,7 +253,8 @@ class ModelConfig:
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
) -> None: ) -> None:
total_num_attention_heads = self.hf_text_config.num_attention_heads total_num_attention_heads = getattr(self.hf_text_config,
"num_attention_heads", 0)
tensor_parallel_size = parallel_config.tensor_parallel_size tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0: if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError( raise ValueError(
...@@ -236,13 +262,13 @@ class ModelConfig: ...@@ -236,13 +262,13 @@ class ModelConfig:
" must be divisible by tensor parallel size " " must be divisible by tensor parallel size "
f"({tensor_parallel_size}).") f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0: architectures = getattr(self.hf_config, "architectures", [])
raise ValueError( if not all(arch in _PP_SUPPORTED_MODELS
f"Total number of hidden layers ({total_num_hidden_layers}) " for arch in architectures) and pipeline_parallel_size > 1:
"must be divisible by pipeline parallel size " raise NotImplementedError(
f"({pipeline_parallel_size}).") "Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if self.quantization == "bitsandbytes" and ( if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1 parallel_config.tensor_parallel_size > 1
...@@ -251,8 +277,7 @@ class ModelConfig: ...@@ -251,8 +277,7 @@ class ModelConfig:
"BitAndBytes quantization with TP or PP is not supported yet.") "BitAndBytes quantization with TP or PP is not supported yet.")
def get_hf_config_sliding_window(self) -> Optional[int]: def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled. """Get the sliding window size, or None if disabled."""
"""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present # addition to sliding window size. We check if that field is present
...@@ -307,7 +332,11 @@ class ModelConfig: ...@@ -307,7 +332,11 @@ class ModelConfig:
return 1 return 1
# For DBRX and MPT # For DBRX and MPT
if self.hf_config.model_type in ["dbrx", "mpt"]: if self.hf_config.model_type == "mpt":
if "kv_n_heads" in self.hf_config.attn_config:
return self.hf_config.attn_config["kv_n_heads"]
return self.hf_config.num_attention_heads
if self.hf_config.model_type == "dbrx":
return getattr(self.hf_config.attn_config, "kv_n_heads", return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads) self.hf_config.num_attention_heads)
...@@ -341,12 +370,43 @@ class ModelConfig: ...@@ -341,12 +370,43 @@ class ModelConfig:
def get_num_attention_heads(self, def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int: parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \ num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
parallel_config.tensor_parallel_size return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers from vllm.distributed.utils import get_pp_indices
return total_num_hidden_layers // parallel_config.pipeline_parallel_size total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
def contains_seqlen_agnostic_layers(
        self, parallel_config: "ParallelConfig") -> bool:
    """Return True if at least one layer is not an attention layer.

    This is the case for Mamba/SSM models (Jamba).
    """
    num_agnostic = self._get_num_seqlen_agnostic_layers(parallel_config)
    return num_agnostic > 0
def get_layers_block_type(self,
                          parallel_config: "ParallelConfig") -> List[str]:
    """Return the block type of each layer.

    Uses ``hf_config.layers_block_type`` when the config provides it
    (Transformers exposes it as a property); otherwise every layer reported
    by ``get_num_layers`` is treated as "attention".
    """
    # The fallback is built up front, so get_num_layers is always called.
    all_attention = ["attention"] * self.get_num_layers(parallel_config)
    return getattr(self.hf_config, "layers_block_type", all_attention)
def get_num_attention_layers(self,
                             parallel_config: "ParallelConfig") -> int:
    """Count the layers whose block type is "attention"."""
    block_types = self.get_layers_block_type(parallel_config)
    return sum(1 for block_type in block_types if block_type == "attention")
def _get_num_seqlen_agnostic_layers(
        self, parallel_config: "ParallelConfig") -> int:
    """Count the layers whose block type is anything other than "attention"."""
    block_types = self.get_layers_block_type(parallel_config)
    return sum(1 for block_type in block_types if block_type != "attention")
class CacheConfig: class CacheConfig:
...@@ -611,45 +671,50 @@ class ParallelConfig: ...@@ -611,45 +671,50 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
from torch.cuda import device_count
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray is not None ray_found = ray_utils.ray_is_available()
if device_count() < self.world_size: if cuda_device_count_stateless() < self.world_size:
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
"required for multi-node inference") "required for multi-node inference, "
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
backend = "ray" backend = "ray"
elif ray_found: elif ray_found:
from ray.util import get_current_placement_group if self.placement_group:
if self.placement_group or get_current_placement_group():
backend = "ray" backend = "ray"
else:
from ray import is_initialized as ray_is_initialized
if ray_is_initialized():
from ray.util import get_current_placement_group
if get_current_placement_group():
backend = "ray"
self.distributed_executor_backend = backend self.distributed_executor_backend = backend
logger.info("Defaulting to use %s for distributed inference", logger.info("Defaulting to use %s for distributed inference",
backend) backend)
self._verify_args() self._verify_args()
self.rank = 0
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1: if (self.pipeline_parallel_size > 1
raise NotImplementedError( and self.distributed_executor_backend == "mp"):
"Pipeline parallelism is not supported yet.") raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None): if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError( raise ValueError(
"Unrecognized distributed executor backend. Supported values " "Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.") "are 'ray' or 'mp'.")
if not self.disable_custom_all_reduce and self.world_size > 1: if self.distributed_executor_backend == "ray":
if is_hip(): from vllm.executor import ray_utils
self.disable_custom_all_reduce = True ray_utils.assert_ray_available()
logger.info( if is_hip():
"Disabled the custom all-reduce kernel because it is not " self.disable_custom_all_reduce = True
"supported on AMD GPUs.") logger.info(
elif self.pipeline_parallel_size > 1: "Disabled the custom all-reduce kernel because it is not "
self.disable_custom_all_reduce = True "supported on AMD GPUs.")
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
if self.ray_workers_use_nsight and ( if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"): not self.distributed_executor_backend == "ray"):
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
...@@ -720,7 +785,6 @@ class SchedulerConfig: ...@@ -720,7 +785,6 @@ class SchedulerConfig:
self.chunked_prefill_enabled = enable_chunked_prefill self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode self.preemption_mode = preemption_mode
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -754,8 +818,14 @@ class DeviceConfig: ...@@ -754,8 +818,14 @@ class DeviceConfig:
# Automated device type detection # Automated device type detection
if is_neuron(): if is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu(): elif is_cpu():
self.device_type = "cpu" self.device_type = "cpu"
elif is_xpu():
self.device_type = "xpu"
else: else:
# We don't call torch.cuda.is_available() here to # We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked # avoid initializing CUDA before workers are forked
...@@ -765,8 +835,10 @@ class DeviceConfig: ...@@ -765,8 +835,10 @@ class DeviceConfig:
self.device_type = device self.device_type = device
# Some device types require processing inputs on CPU # Some device types require processing inputs on CPU
if self.device_type in ["neuron"]: if self.device_type in ["neuron", "openvino"]:
self.device = torch.device("cpu") self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else: else:
# Set device with device type # Set device with device type
self.device = torch.device(self.device_type) self.device = torch.device(self.device_type)
...@@ -785,6 +857,7 @@ class SpeculativeConfig: ...@@ -785,6 +857,7 @@ class SpeculativeConfig:
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
target_dtype: str, target_dtype: str,
speculative_model: Optional[str], speculative_model: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int], num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
...@@ -792,6 +865,9 @@ class SpeculativeConfig: ...@@ -792,6 +865,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size: Optional[int], speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None. """Create a SpeculativeConfig if possible, else return None.
...@@ -807,8 +883,11 @@ class SpeculativeConfig: ...@@ -807,8 +883,11 @@ class SpeculativeConfig:
target_dtype (str): The data type used for the target model. target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative speculative_model (Optional[str]): The name of the speculative
model, if provided. model, if provided.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_max_model_len (Optional[int]): The maximum model len of speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip the speculative model. Used when testing the ability to skip
speculation for some sequences. speculation for some sequences.
...@@ -825,30 +904,37 @@ class SpeculativeConfig: ...@@ -825,30 +904,37 @@ class SpeculativeConfig:
window, if provided. window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided. window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None. the necessary conditions are met, else None.
""" """
if speculative_model is None and num_speculative_tokens is None: if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None return None
if speculative_model is not None and num_speculative_tokens is None:
raise ValueError(
"Expected both speculative_model and "
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
if (speculative_disable_by_batch_size is not None if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2): and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling " raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got " "speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}") f"{speculative_disable_by_batch_size=}")
assert (speculative_model is not None
and num_speculative_tokens is not None)
if enable_chunked_prefill: if enable_chunked_prefill:
raise ValueError( raise ValueError(
"Speculative decoding and chunked prefill are " "Speculative decoding and chunked prefill are "
...@@ -902,6 +988,25 @@ class SpeculativeConfig: ...@@ -902,6 +988,25 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs, max_logprobs=target_model_config.max_logprobs,
) )
draft_hf_config = draft_model_config.hf_config
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict:
# Verify provided value doesn't exceed the maximum
# supported by the draft model.
raise ValueError(
"This speculative model supports a maximum of "
f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.")
draft_model_config.max_model_len = ( draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len( SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len, speculative_max_model_len,
...@@ -911,7 +1016,19 @@ class SpeculativeConfig: ...@@ -911,7 +1016,19 @@ class SpeculativeConfig:
draft_parallel_config = ( draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config( SpeculativeConfig.create_draft_parallel_config(
target_parallel_config)) target_parallel_config,
speculative_draft_tensor_parallel_size))
if num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter.")
if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
return SpeculativeConfig( return SpeculativeConfig(
draft_model_config, draft_model_config,
...@@ -920,6 +1037,11 @@ class SpeculativeConfig: ...@@ -920,6 +1037,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size, speculative_disable_by_batch_size,
ngram_prompt_lookup_max, ngram_prompt_lookup_max,
ngram_prompt_lookup_min, ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
) )
@staticmethod @staticmethod
...@@ -959,16 +1081,26 @@ class SpeculativeConfig: ...@@ -959,16 +1081,26 @@ class SpeculativeConfig:
@staticmethod @staticmethod
def create_draft_parallel_config( def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig: target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker. """Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config. In the future the This is mostly a copy of the target parallel config, except the tp_size.
draft worker can have a different parallel strategy, e.g. TP=1.
""" """
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be"
f"other value than 1")
draft_parallel_config = ParallelConfig( draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config. pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size, pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config. distributed_executor_backend=target_parallel_config.
distributed_executor_backend, distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config. max_parallel_loading_workers=target_parallel_config.
...@@ -991,6 +1123,9 @@ class SpeculativeConfig: ...@@ -991,6 +1123,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size: Optional[int], speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1004,6 +1139,19 @@ class SpeculativeConfig: ...@@ -1004,6 +1139,19 @@ class SpeculativeConfig:
enqueue requests is larger than this value. enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window. ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window. ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1012,6 +1160,11 @@ class SpeculativeConfig: ...@@ -1012,6 +1160,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self._verify_args() self._verify_args()
...@@ -1023,6 +1176,31 @@ class SpeculativeConfig: ...@@ -1023,6 +1176,31 @@ class SpeculativeConfig:
if self.draft_model_config: if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config( self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config) self.draft_parallel_config)
# Validate and set draft token acceptance related settings.
if (self.draft_token_acceptance_method is None):
raise ValueError("draft_token_acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method !=
'typical_acceptance_sampler'):
raise ValueError(
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.draft_token_acceptance_method}")
if (self.typical_acceptance_sampler_posterior_threshold < 0
or self.typical_acceptance_sampler_posterior_alpha < 0):
raise ValueError(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f"typical_acceptance_sampler_posterior_threshold = "
f"{self.typical_acceptance_sampler_posterior_threshold} and "
f"typical_acceptance_sampler_posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_alpha}")
@property @property
def num_lookahead_slots(self) -> int: def num_lookahead_slots(self) -> int:
...@@ -1094,79 +1272,49 @@ class LoRAConfig: ...@@ -1094,79 +1272,49 @@ class LoRAConfig:
"Due to limitations of the custom LoRA CUDA kernel, " "Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when " "max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.") "LoRA is enabled.")
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
@dataclass @dataclass
class VisionLanguageConfig: class PromptAdapterConfig:
"""Configs the input data format and how models should run for max_prompt_adapters: int
vision language models.""" max_prompt_adapter_token: int
max_cpu_prompt_adapters: Optional[int] = None
class ImageInputType(enum.Enum): prompt_adapter_dtype: Optional[torch.dtype] = None
"""Image input type into the vision language model.
An image roughly goes through the following transformation: def __post_init__(self):
Raw image --> pixel values --> image features --> image embeddings. library_name = 'peft'
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES = enum.auto()
IMAGE_FEATURES = enum.auto()
image_input_type: ImageInputType
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
# The image processor to load from HuggingFace
image_processor: Optional[str]
image_processor_revision: Optional[str]
@classmethod
def get_image_input_enum_type(cls, value: str) -> ImageInputType:
"""Get the image input type from a string."""
try: try:
return cls.ImageInputType[value.upper()] __import__(library_name)
except KeyError as e: except ImportError as e:
raise ValueError(f"{value} is not a valid choice. " raise ImportError(
f"Expecting to choose from " f"'{library_name}' is not installed for prompt adapter support."
f"{[x.name for x in cls.ImageInputType]}.") from e f"Please install it using 'pip install {library_name}'."
) from e
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class. if self.max_prompt_adapters < 1:
def get_image_token_text( raise ValueError(f"max_prompt_adapters "
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: f"({self.max_prompt_adapters}) must be >= 1.")
"""Get the image token placeholder text to be inserted into the if self.max_prompt_adapter_token == 0:
text prompt and the string representation of the image token id. raise ValueError("max_prompt_adapter_token must be set.")
""" if self.max_cpu_prompt_adapters is None:
image_token_str = tokenizer.decode(self.image_token_id) self.max_cpu_prompt_adapters = self.max_prompt_adapters
return image_token_str * self.image_feature_size, image_token_str
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects. def verify_with_model_config(self, model_config: ModelConfig):
""" if self.prompt_adapter_dtype in (None, "auto"):
result: Dict[str, Any] = {} self.prompt_adapter_dtype = model_config.dtype
for f in fields(self): elif isinstance(self.prompt_adapter_dtype, str):
value = getattr(self, f.name) self.prompt_adapter_dtype = getattr(torch,
if isinstance(value, enum.Enum): self.prompt_adapter_dtype)
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value
result["disable_image_processor"] = self.image_processor is None
return result @dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
multimodal models."""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
...@@ -1194,10 +1342,16 @@ def _get_and_verify_dtype( ...@@ -1194,10 +1342,16 @@ def _get_and_verify_dtype(
dtype = dtype.lower() dtype = dtype.lower()
if dtype == "auto": if dtype == "auto":
if config_dtype == torch.float32: if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 if config.model_type == "gemma2":
# models. logger.info(
logger.info("Casting torch.float32 to torch.float16.") "For Gemma 2, we downcast float32 to bfloat16 instead "
torch_dtype = torch.float16 "of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else: else:
torch_dtype = config_dtype torch_dtype = config_dtype
else: else:
...@@ -1282,7 +1436,10 @@ def _get_and_verify_max_len( ...@@ -1282,7 +1436,10 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None) rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su": # The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] != "su" \
and rope_scaling["type"] != "longrope":
if disable_sliding_window: if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling # TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed. # with sliding window to see if this case should be allowed.
...@@ -1357,6 +1514,17 @@ class DecodingConfig: ...@@ -1357,6 +1514,17 @@ class DecodingConfig:
f"must be one of {valid_guided_backends}") f"must be one of {valid_guided_backends}")
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    # Endpoint that traces are sent to; None leaves it unconfigured.
    otlp_traces_endpoint: Optional[str] = None

    def __post_init__(self):
        """Reject a configured endpoint when OpenTelemetry is unavailable.

        Raises:
            ValueError: If ``otlp_traces_endpoint`` is set while
                ``is_otel_installed()`` reports False.
        """
        # The availability probe runs first, as before, regardless of
        # whether an endpoint was given.
        otel_missing = not is_otel_installed()
        endpoint_requested = self.otlp_traces_endpoint is not None
        if otel_missing and endpoint_requested:
            raise ValueError("OpenTelemetry packages must be installed before "
                             "configuring 'otlp_traces_endpoint'")
@dataclass(frozen=True) @dataclass(frozen=True)
class EngineConfig: class EngineConfig:
"""Dataclass which contains all engine-related configuration. This """Dataclass which contains all engine-related configuration. This
...@@ -1370,9 +1538,11 @@ class EngineConfig: ...@@ -1370,9 +1538,11 @@ class EngineConfig:
device_config: DeviceConfig device_config: DeviceConfig
load_config: LoadConfig load_config: LoadConfig
lora_config: Optional[LoRAConfig] lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig] multimodal_config: Optional[MultiModalConfig]
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig] decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
...@@ -1384,6 +1554,9 @@ class EngineConfig: ...@@ -1384,6 +1554,9 @@ class EngineConfig:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self): def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs. """Return the configs as a dictionary, for use in **kwargs.
......
from typing import List, Optional from typing import List, Optional
from vllm.core.block.common import BlockList
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list from vllm.utils import Device, cdiv, chunk_list
...@@ -47,12 +48,10 @@ class BlockTable: ...@@ -47,12 +48,10 @@ class BlockTable:
self._allocator = block_allocator self._allocator = block_allocator
if _blocks is None: if _blocks is None:
_blocks = [] _blocks = []
self._blocks: List[Block] = _blocks self._blocks: BlockList = BlockList(_blocks)
self._max_block_sliding_window = max_block_sliding_window self._max_block_sliding_window = max_block_sliding_window
# Use helper method instead of directly calculating, as blocks self._num_full_slots = self._get_num_token_ids()
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
@staticmethod @staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
...@@ -88,11 +87,18 @@ class BlockTable: ...@@ -88,11 +87,18 @@ class BlockTable:
""" """
assert not self._is_allocated assert not self._is_allocated
assert token_ids assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None, blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids, token_ids=token_ids,
device=device) device=device)
self.update(blocks)
self._num_full_slots = len(token_ids) self._num_full_slots = len(token_ids)
def update(self, blocks: List[Block]) -> None:
    """Point this table at a new set of blocks.

    Delegates to the ``update`` method of the underlying ``_blocks``
    container, which resets the table to the newly provided blocks
    (with their corresponding block ids).

    Args:
        blocks: The blocks the table should hold from now on.
    """
    block_list = self._blocks
    block_list.update(blocks)
def append_token_ids(self, def append_token_ids(self,
token_ids: List[int], token_ids: List[int],
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
...@@ -140,11 +146,11 @@ class BlockTable: ...@@ -140,11 +146,11 @@ class BlockTable:
num_lookahead_slots) num_lookahead_slots)
# Update the blocks with the new tokens # Update the blocks with the new tokens
blocks = self._blocks[self._num_full_slots // self._block_size:] first_block_idx = self._num_full_slots // self._block_size
token_blocks = self._chunk_token_blocks_for_append(token_ids) token_blocks = self._chunk_token_blocks_for_append(token_ids)
for block, token_block in zip(blocks, token_blocks): for i, token_block in enumerate(token_blocks):
block.append_token_ids(token_block) self._blocks.append_token_ids(first_block_idx + i, token_block)
self._num_full_slots += len(token_ids) self._num_full_slots += len(token_ids)
...@@ -174,8 +180,8 @@ class BlockTable: ...@@ -174,8 +180,8 @@ class BlockTable:
for _ in range(blocks_to_allocate): for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0 assert len(self._blocks) > 0
self._blocks.append( self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1], self._allocator.allocate_mutable_block(
device=device)) prev_block=self._blocks[-1], device=device))
def fork(self) -> "BlockTable": def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the """Creates a new BlockTable instance with a copy of the blocks from the
...@@ -209,12 +215,12 @@ class BlockTable: ...@@ -209,12 +215,12 @@ class BlockTable:
is set to `None`. is set to `None`.
""" """
assert self._is_allocated assert self._is_allocated
for block in self._blocks: for block in self.blocks:
self._allocator.free(block) self._allocator.free(block)
self._blocks = [] self._blocks.reset()
@property @property
def physical_block_ids(self) -> List[Optional[int]]: def physical_block_ids(self) -> List[int]:
"""Returns a list of physical block indices for the blocks in the """Returns a list of physical block indices for the blocks in the
BlockTable. BlockTable.
...@@ -228,7 +234,7 @@ class BlockTable: ...@@ -228,7 +234,7 @@ class BlockTable:
BlockTable. BlockTable.
""" """
assert self._is_allocated assert self._is_allocated
return [block.block_id for block in self._blocks] return self._blocks.ids()
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence. """Get the number of "unseen" tokens in the sequence.
...@@ -252,18 +258,32 @@ class BlockTable: ...@@ -252,18 +258,32 @@ class BlockTable:
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Device) -> List[Block]: device: Device) -> List[Block]:
blocks = [] blocks: List[Block] = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size: block_token_ids = []
# If the block is full, create an immutable block. tail_token_ids = []
prev_block = self._allocator.allocate_immutable( for cur_token_ids in chunk_list(token_ids, self._block_size):
prev_block, token_ids=block_token_ids, device=device) if len(cur_token_ids) == self._block_size:
block_token_ids.append(cur_token_ids)
else: else:
# Else, partially fill a mutable block with token ids. tail_token_ids.append(cur_token_ids)
prev_block = self._allocator.allocate_mutable(
prev_block=prev_block, device=device) if block_token_ids:
prev_block.append_token_ids(block_token_ids) blocks.extend(
blocks.append(prev_block) self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids,
device=device))
prev_block = blocks[-1]
if tail_token_ids:
assert len(tail_token_ids) == 1
cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device)
block.append_token_ids(cur_token_ids)
blocks.append(block)
return blocks return blocks
...@@ -274,18 +294,25 @@ class BlockTable: ...@@ -274,18 +294,25 @@ class BlockTable:
if not self._is_allocated: if not self._is_allocated:
return token_ids return token_ids
for block in self._blocks: for block in self.blocks:
token_ids.extend(block.token_ids) token_ids.extend(block.token_ids)
return token_ids return token_ids
def _get_num_token_ids(self) -> int:
    """Return the total number of token ids held across ``self.blocks``.

    Only the length of each block's ``token_ids`` is read; the ids
    themselves are not copied into a combined list.
    """
    return sum(len(block.token_ids) for block in self.blocks)
@property @property
def _is_allocated(self) -> bool: def _is_allocated(self) -> bool:
return len(self._blocks) > 0 return len(self._blocks) > 0
@property @property
def blocks(self) -> Optional[List[Block]]: def blocks(self) -> List[Block]:
return self._blocks return self._blocks.list()
@property @property
def _num_empty_slots(self) -> int: def _num_empty_slots(self) -> int:
......
from typing import Dict, Iterable, List, Optional, Protocol, Tuple from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator
...@@ -95,64 +96,40 @@ class CopyOnWriteTracker: ...@@ -95,64 +96,40 @@ class CopyOnWriteTracker:
The CopyOnWriteTracker class maintains a mapping of source block indices to The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference conjunction with a RefCounter.
counting and block allocation.
Args: Args:
refcounter (RefCounter): The reference counter used to track block refcounter (RefCounter): The reference counter used to track block
reference counts. reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
""" """
def __init__( def __init__(self, refcounter: RefCounterProtocol):
self,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] self._copy_on_writes: List[Tuple[BlockId, BlockId]] = []
self._refcounter = refcounter self._refcounter = refcounter
self._allocator = allocator
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns: def is_appendable(self, block: Block) -> bool:
Optional[BlockId]: The block index of the new block if a copy-on """Checks if the block is shared or not. If shared, then it cannot
-write operation was performed, or the original block index if be appended and needs to be duplicated via copy-on-write
no copy-on-write was necessary.
""" """
block_id = block.block_id block_id = block.block_id
if block_id is None: if block_id is None:
return block_id return True
refcount = self._refcounter.get(block_id) refcount = self._refcounter.get(block_id)
assert refcount != 0 return refcount <= 1
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)
# Allocate a fresh new block.
block_id = self._allocator.allocate_mutable(
prev_block=block.prev_block).block_id
# Track src/dst copy. def record_cow(self, src_block_id: Optional[BlockId],
assert src_block_id is not None trg_block_id: Optional[BlockId]) -> None:
assert block_id is not None """Records a copy-on-write operation from source to target block id
self._copy_on_writes.append((src_block_id, block_id)) Args:
src_block_id (BlockId): The source block id from which to copy
return block_id the data
trg_block_id (BlockId): The target block id to which the data
is copied
"""
assert src_block_id is not None
assert trg_block_id is not None
self._copy_on_writes.append((src_block_id, trg_block_id))
def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: def clear_cows(self) -> List[Tuple[BlockId, BlockId]]:
"""Clears the copy-on-write tracking information and returns the current """Clears the copy-on-write tracking information and returns the current
...@@ -172,6 +149,139 @@ class CopyOnWriteTracker: ...@@ -172,6 +149,139 @@ class CopyOnWriteTracker:
return cows return cows
class BlockPool:
    """Used to pre-allocate block objects, in order to avoid excessive python
    object allocations/deallocations.

    The pool starts from "pool_size" objects and grows on demand when every
    pooled object is in use.

    Note that multiple block objects may point to the same physical block id,
    which is why this pool is needed, so that it will be easier to support
    prefix caching and more complicated sharing of physical blocks.

    Args:
        block_size: Block size passed to the factory for every pooled block.
        create_block: Factory used to construct the pooled block objects.
        allocator: Allocator handed to every pooled block on construction.
        pool_size: Initial number of pre-allocated block objects (>= 0).
    """

    def __init__(self, block_size: int, create_block: Block.Factory,
                 allocator: BlockAllocator, pool_size: int):
        self._block_size = block_size
        self._create_block = create_block
        self._allocator = allocator
        self._pool_size = pool_size
        assert self._pool_size >= 0

        # Pool slots that are currently not handed out.
        self._free_ids: Deque[int] = deque(range(self._pool_size))
        self._pool: List[Block] = [
            self._new_pooled_block() for _ in range(self._pool_size)
        ]

    def _new_pooled_block(self) -> Block:
        """Build one blank block object to be stored in the pool."""
        return self._create_block(prev_block=None,
                                  token_ids=[],
                                  block_size=self._block_size,
                                  allocator=self._allocator,
                                  block_id=None)

    def increase_pool(self):
        """Doubles the internal pool size (an empty pool grows to one slot).
        """
        cur_pool_size = self._pool_size
        # Doubling an empty pool would leave it empty, and init_block would
        # then have nothing to hand out; always add at least one slot.
        new_pool_size = max(cur_pool_size * 2, 1)
        self._pool_size = new_pool_size

        self._free_ids.extend(range(cur_pool_size, new_pool_size))
        self._pool.extend(self._new_pooled_block()
                          for _ in range(cur_pool_size, new_pool_size))

    def init_block(self, prev_block: Optional[Block], token_ids: List[int],
                   block_size: int, physical_block_id: Optional[int]) -> Block:
        """Take a free block object from the pool and re-initialize it.

        Args:
            prev_block: Value passed as ``prev_block`` to the block.
            token_ids: Value passed as ``token_ids`` to the block.
            block_size: Value passed as ``block_size`` to the block.
            physical_block_id: Value passed as ``block_id`` to the block.

        Returns:
            Block: The recycled block object; its ``pool_id`` attribute
                records the pool slot it came from.
        """
        if not self._free_ids:
            self.increase_pool()
            assert self._free_ids

        pool_id = self._free_ids.popleft()

        block = self._pool[pool_id]
        # Re-run the constructor in place instead of creating a new object.
        block.__init__(  # type: ignore[misc]
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            allocator=block._allocator,  # type: ignore[attr-defined]
            block_id=physical_block_id)
        block.pool_id = pool_id  # type: ignore[attr-defined]
        return block

    def free_block(self, block: Block) -> None:
        """Return a block object obtained from init_block to the pool.

        The slot goes to the front of the free queue, so it is the next one
        handed out by init_block.
        """
        self._free_ids.appendleft(block.pool_id)  # type: ignore[attr-defined]
class BlockList:
    """Holds a list of blocks together with a parallel list of their
    physical block ids.

    Keeping the id list in step with the block list gives fast access to
    the ids and avoids rebuilding the id list on every iteration of the
    block manager.
    """

    def __init__(self, blocks: List[Block]):
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional[BlockId]) -> None:
        """Append a (non-None) block id to the cached id list."""
        assert block_id is not None
        self._block_ids.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional[BlockId]) -> None:
        """Overwrite the cached id at ``block_index`` with a non-None id."""
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def update(self, blocks: List[Block]):
        """Adopt ``blocks`` as the block list and rebuild the id cache."""
        self._blocks = blocks

        # Cache block ids for fast query
        self._block_ids = []
        for member in blocks:
            self._add_block_id(member.block_id)

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        """Append token ids to the block at ``block_index`` and refresh its
        cached id if the append changed the block's id.
        """
        target = self._blocks[block_index]
        id_before = target.block_id

        target.append_token_ids(token_ids)

        # CoW or promotion may update the internal block_id
        if id_before != target.block_id:
            self._update_block_id(block_index, target.block_id)

    def append(self, new_block: Block):
        """Add a block at the end and cache its id."""
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, block_index: int) -> Block:
        return self._blocks[block_index]

    def __setitem__(self, block_index: int, new_block: Block) -> None:
        self._blocks[block_index] = new_block
        self._update_block_id(block_index, new_block.block_id)

    def reset(self):
        """Drop all blocks and cached ids."""
        self._blocks, self._block_ids = [], []

    def list(self) -> List[Block]:
        """Return the underlying block list (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the cached block id list (not a copy)."""
        return self._block_ids
def get_all_blocks_recursively(last_block: Block) -> List[Block]: def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block. """Retrieves all the blocks in a sequence starting from the last block.
......
...@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def allocate_or_get_null_block(self) -> Block: def allocate_or_get_null_block(self) -> Block:
if self._null_block is None: if self._null_block is None:
self._null_block = NullBlock( self._null_block = NullBlock(
self.allocate_mutable(None, Device.GPU)) self.allocate_mutable_block(None, Device.GPU))
return self._null_block return self._null_block
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block: device: Device) -> Block:
"""Allocates a new mutable block on the specified device. """Allocates a new mutable block on the specified device.
Args: Args:
...@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
return self._allocators[device].allocate_mutable(prev_block) return self._allocators[device].allocate_mutable_block(prev_block)
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable_blocks(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block: block_token_ids: List[List[int]],
device: Optional[Device]) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids)
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the """Allocates a new immutable block with the provided token IDs on the
specified device. specified device.
...@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided Block: The newly allocated immutable block containing the provided
token IDs. token IDs.
""" """
return self._allocators[device].allocate_immutable( return self._allocators[device].allocate_immutable_block(
prev_block, token_ids) prev_block, token_ids)
def free(self, block: Block) -> None: def free(self, block: Block) -> None:
...@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_id = block.block_id block_id = block.block_id
assert block_id is not None assert block_id is not None
allocator = self._block_ids_to_allocator[block_id] allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block) allocator.free(block)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying """Creates a new sequence of blocks that shares the same underlying
...@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
""" """
return self._allocators[device].get_physical_block_id(absolute_id) return self._allocators[device].get_physical_block_id(absolute_id)
def swap(self, blocks: List[Block], source_device: Device, def swap(self, blocks: List[Block], src_device: Device,
dest_device: Device) -> Dict[int, int]: dst_device: Device) -> Dict[int, int]:
"""Execute the swap for the given blocks from source_device """Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each them to the accumulated `self._swap_mapping` for each
...@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args: Args:
blocks: List of blocks to be swapped. blocks: List of blocks to be swapped.
source_device (Device): Device to swap the 'blocks' from. src_device (Device): Device to swap the 'blocks' from.
dest_device (Device): Device to swap the 'blocks' to. dst_device (Device): Device to swap the 'blocks' to.
Returns: Returns:
Dict[int, int]: Swap mapping from source_device Dict[int, int]: Swap mapping from source_device
on to dest_device. on to dest_device.
""" """
source_block_ids = [block.block_id for block in blocks] src_block_ids = [block.block_id for block in blocks]
self._allocators[source_device].swap_out(blocks) self._allocators[src_device].swap_out(blocks)
self._allocators[dest_device].swap_in(blocks) self._allocators[dst_device].swap_in(blocks)
dest_block_ids = [block.block_id for block in blocks] dst_block_ids = [block.block_id for block in blocks]
current_swap_mapping: Dict[int, int] = {} current_swap_mapping: Dict[int, int] = {}
for src, dest in zip(source_block_ids, dest_block_ids): for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
if src is not None and dest is not None: if src_block_id is not None and dst_block_id is not None:
self._swap_mapping[src] = dest self._swap_mapping[src_block_id] = dst_block_id
current_swap_mapping[src] = dest current_swap_mapping[src_block_id] = dst_block_id
return current_swap_mapping return current_swap_mapping
def get_num_blocks_touched(self, def get_num_blocks_touched(self,
...@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids) return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Forward the computed-block-id query to the GPU allocator.

    All three arguments are passed through unchanged, and the GPU
    allocator's result is returned as-is.
    """
    # Prefix caching only supported on GPU.
    gpu_allocator = self._allocators[Device.GPU]
    return gpu_allocator.get_computed_block_ids(prev_computed_block_ids,
                                                block_ids,
                                                skip_last_block_id)
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU. # Prefix caching only supported on GPU.
device = Device.GPU device = Device.GPU
return self._allocators[device].get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids(
seq_block_ids) computed_seq_block_ids)
@property @property
def all_block_ids(self) -> FrozenSet[int]: def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys()) return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
def get_and_reset_swaps(self) -> List[Tuple[int, int]]: def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs. """Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every Will be called after every swapping operations for now, and after every
...@@ -341,6 +364,11 @@ class NullBlock(Block): ...@@ -341,6 +364,11 @@ class NullBlock(Block):
def token_ids(self) -> List[BlockId]: def token_ids(self) -> List[BlockId]:
return self._proxy.token_ids return self._proxy.token_ids
@property
def num_tokens_total(self) -> int:
    """Not available on the null block.

    Raises:
        NotImplementedError: Always.
    """
    message = "num_tokens_total is not used for null block"
    raise NotImplementedError(message)
@property @property
def num_empty_slots(self) -> BlockId: def num_empty_slots(self) -> BlockId:
return self._proxy.num_empty_slots return self._proxy.num_empty_slots
......
...@@ -28,6 +28,13 @@ class Block(ABC): ...@@ -28,6 +28,13 @@ class Block(ABC):
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
pass pass
@property
@abstractmethod
def num_tokens_total(self) -> int:
    """The number of tokens up to and including the current block.

    Returns:
        int: The token count.
            NOTE(review): presumably accumulated over this block and its
            predecessors -- confirm against the concrete block classes.
    """
    pass
@property @property
@abstractmethod @abstractmethod
def num_empty_slots(self) -> int: def num_empty_slots(self) -> int:
...@@ -92,12 +99,18 @@ class Block(ABC): ...@@ -92,12 +99,18 @@ class Block(ABC):
class BlockAllocator(ABC): class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block: token_ids: List[int]) -> Block:
pass
@abstractmethod
def allocate_immutable_blocks(
self, prev_block: Optional[Block],
block_token_ids: List[List[int]]) -> List[Block]:
pass pass
@abstractmethod @abstractmethod
...@@ -146,13 +159,19 @@ class BlockAllocator(ABC): ...@@ -146,13 +159,19 @@ class BlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Return a list of block ids for the given query.

    NOTE(review): presumably the subset of ``block_ids`` whose contents are
    already computed (prefix caching), optionally ignoring the last id --
    confirm against the prefix-caching allocator. An allocator without
    prefix caching returns an empty list.
    """
    pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass pass
@abstractmethod @abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]: def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block""" """NOTE: This should not be used besides Block"""
pass pass
...@@ -174,13 +193,20 @@ class BlockAllocator(ABC): ...@@ -174,13 +193,20 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(ABC): class DeviceAwareBlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block: device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable_blocks(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block: block_token_ids: List[List[int]],
device: Device) -> List[Block]:
pass pass
@abstractmethod @abstractmethod
...@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC): ...@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Device-aware counterpart of ``BlockAllocator.get_computed_block_ids``.

    NOTE(review): presumably returns the block ids already computed for a
    sequence -- confirm the exact meaning of each argument against the
    prefix-caching allocator.
    """
    pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
pass pass
@abstractmethod @abstractmethod
...@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC): ...@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def swap(self, blocks: List[Block], source_device: Device, def swap(self, blocks: List[Block], src_device: Device,
dest_device: Device) -> Dict[int, int]: dst_device: Device) -> Dict[int, int]:
pass pass
@abstractmethod @abstractmethod
......
from typing import FrozenSet, Iterable, List, Optional, Set, Tuple from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator):
num_blocks: int, num_blocks: int,
block_size: int, block_size: int,
block_ids: Optional[Iterable[int]] = None, block_ids: Optional[Iterable[int]] = None,
block_pool: Optional[BlockPool] = None,
): ):
if block_ids is None: if block_ids is None:
block_ids = range(num_blocks) block_ids = range(num_blocks)
self._free_block_indices: Set[BlockId] = set(block_ids) self._free_block_indices: Deque[BlockId] = deque(block_ids)
self._all_block_indices = frozenset(block_ids) self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter( self._refcounter = RefCounter(
all_block_indices=self._free_block_indices) all_block_indices=self._free_block_indices)
self._create_block = create_block
self._block_size = block_size self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker( self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(), refcounter=self._refcounter.as_readonly())
allocator=self,
) if block_pool is None:
extra_factor = 4
def allocate_immutable(self, # Pre-allocate "num_blocks * extra_factor" block objects.
prev_block: Optional[Block], # The "* extra_factor" is a buffer to allow more block objects
token_ids: List[int], # than physical blocks
device: Optional[Device] = None) -> Block: self._block_pool = BlockPool(self._block_size, create_block, self,
num_blocks * extra_factor)
else:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self._block_pool = block_pool
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to """Allocates a new immutable block with the given token IDs, linked to
the previous block. the previous block.
...@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated immutable block. Block: The newly allocated immutable block.
""" """
assert device is None assert device is None
block = self.allocate_mutable(prev_block=prev_block) block = self.allocate_mutable_block(prev_block=prev_block)
block.append_token_ids(token_ids) block.append_token_ids(token_ids)
return block return block
def allocate_mutable(self, def allocate_immutable_blocks(
prev_block: Optional[Block], self,
device: Optional[Device] = None) -> Block: prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
assert device is None
num_blocks = len(block_token_ids)
block_ids = []
for i in range(num_blocks):
block_ids.append(self._allocate_block_id())
blocks = []
for i in range(num_blocks):
prev_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block_token_ids[i],
block_size=self._block_size,
physical_block_id=block_ids[i])
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block. """Allocates a new mutable block, linked to the previous block.
Args: Args:
...@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
assert device is None assert device is None
block_id = self._allocate_new_block_id() block_id = self._allocate_block_id()
return self._create_block( block = self._block_pool.init_block(prev_block=prev_block,
prev_block=prev_block, token_ids=[],
token_ids=[], block_size=self._block_size,
block_id=block_id, physical_block_id=block_id)
block_size=self._block_size, return block
allocator=self,
) def _allocate_block_id(self) -> BlockId:
if not self._free_block_indices:
def free(self, block: Block) -> None: raise BlockAllocator.NoFreeBlocksError()
assert block.block_id is not None
self._free_block_id(block.block_id) block_id = self._free_block_indices.popleft()
self._refcounter.incr(block_id)
return block_id
def _free_block_id(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
block.block_id = None block.block_id = None
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
self._free_block_id(block)
# Release the block object
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying """Creates a new sequence of blocks that shares the same underlying
memory as the original sequence. memory as the original sequence.
...@@ -111,7 +165,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -111,7 +165,7 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
source_blocks = get_all_blocks_recursively(last_block) source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = [] forked_blocks: List[Block] = []
prev_block = None prev_block = None
for block in source_blocks: for block in source_blocks:
...@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
refcount = self._refcounter.incr(block.block_id) refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block" assert refcount != 1, "can't fork free'd block"
forked_blocks.append( forked_block = self._block_pool.init_block(
self._create_block( prev_block=prev_block,
prev_block=prev_block, token_ids=block.token_ids,
token_ids=block.token_ids, block_size=self._block_size,
block_id=block.block_id, physical_block_id=block.block_id)
block_size=self._block_size,
allocator=self, forked_blocks.append(forked_block)
))
prev_block = forked_blocks[-1] prev_block = forked_blocks[-1]
return forked_blocks return forked_blocks
...@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
def get_num_total_blocks(self) -> int: def get_num_total_blocks(self) -> int:
return len(self._all_block_indices) return len(self._all_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = next(iter(self._free_block_indices))
self._refcounter.incr(block_id)
self._free_block_indices.remove(block_id)
return block_id
def _free_block_id(self, block_id: BlockId) -> None:
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.add(block_id)
def get_physical_block_id(self, absolute_id: int) -> int: def get_physical_block_id(self, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain block allocator """Returns the zero-offset block id on certain block allocator
given the absolute block id. given the absolute block id.
...@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]: def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not """Performs a copy-on-write operation on the given block if it is not
appendable. appendable.
...@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write. block (Block): The block to check for copy-on-write.
Returns: Returns:
Optional[BlockId]: The block index of the new block if a copy-on BlockId: The block index of the new block if a copy-on-write
-write operation was performed, or the original block index if operation was performed, or the original block index if
no copy-on-write was necessary. no copy-on-write was necessary.
""" """
return self._cow_tracker.cow_block_if_not_appendable(block) src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it. """Returns the copy-on-write source->destination mapping and clears it.
...@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
pass pass
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Report that no blocks are computed.

    This allocator has no prefix caching, so the arguments are ignored and
    a new empty list is returned on every call.
    """
    no_computed_blocks: List[int] = []
    return no_computed_blocks
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill. """Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return Since the naive allocator does not support prefix caching, always return
...@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
return [] return []
def promote_to_immutable_block(self, block: Block) -> BlockId: def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError raise NotImplementedError("There is no promotion for naive blocks")
def get_num_blocks_touched(self, def get_num_blocks_touched(self,
blocks: List[Block], blocks: List[Block],
...@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
def swap_out(self, blocks: List[Block]) -> None: def swap_out(self, blocks: List[Block]) -> None:
for block in blocks: for block in blocks:
self.free(block) self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None: def swap_in(self, blocks: List[Block]) -> None:
for block in blocks: for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full: if block.is_full:
alloc = self.allocate_immutable(block.prev_block, tmp_block = self.allocate_immutable_block(
block.token_ids) prev_block=block.prev_block, token_ids=block.token_ids)
else: else:
alloc = self.allocate_mutable(block.prev_block) tmp_block = self.allocate_mutable_block(
alloc.append_token_ids(block.token_ids) prev_block=block.prev_block)
block.block_id = alloc.block_id tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
tmp_block.block_id = None
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class NaiveBlock(Block): class NaiveBlock(Block):
...@@ -315,11 +382,12 @@ class NaiveBlock(Block): ...@@ -315,11 +382,12 @@ class NaiveBlock(Block):
self._append_token_ids_no_cow(token_ids) self._append_token_ids_no_cow(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None: def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block, instructing the allocator """Appends the given token IDs to the block and performs a
to perform a copy-on-write if necessary. copy-on-write if necessary.
Args: Args:
token_ids (List[int]): The token IDs to be appended to the block. token_ids (Optional[List[int]]): The token IDs to be appended
to the block.
""" """
self._append_token_ids_no_cow(token_ids) self._append_token_ids_no_cow(token_ids)
...@@ -328,7 +396,16 @@ class NaiveBlock(Block): ...@@ -328,7 +396,16 @@ class NaiveBlock(Block):
self._cow_target)) self._cow_target))
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids) """Appends the given token IDs to the block
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
if len(token_ids) == 0:
return
assert len(token_ids) <= self.num_empty_slots
self._token_ids.extend(token_ids) self._token_ids.extend(token_ids)
@property @property
...@@ -361,12 +438,17 @@ class NaiveBlock(Block): ...@@ -361,12 +438,17 @@ class NaiveBlock(Block):
@property @property
def num_empty_slots(self) -> int: def num_empty_slots(self) -> int:
return self._block_size - len(self._token_ids) return self._block_size - len(self.token_ids)
@property @property
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
return self._token_ids return self._token_ids
@property
def num_tokens_total(self) -> int:
    """Unsupported for this block type; every access raises.

    Raises:
        NotImplementedError: Always.
    """
    reason = "num_tokens_total is not used for naive block"
    raise NotImplementedError(reason)
@property @property
def block_size(self) -> int: def block_size(self) -> int:
return self._block_size return self._block_size
......
"""Token blocks.""" """Token blocks."""
from itertools import takewhile
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -19,6 +19,30 @@ PrefixHash = int ...@@ -19,6 +19,30 @@ PrefixHash = int
_DEFAULT_LAST_ACCESSED_TIME = -1 _DEFAULT_LAST_ACCESSED_TIME = -1
class BlockTracker:
    """Status record for one block inside the prefix caching allocator.

    The three fields are declared through ``__slots__``:
      * ``active``: flipped by enable() / disable().
      * ``last_accessed``: restored to _DEFAULT_LAST_ACCESSED_TIME by reset().
      * ``computed``: restored to False by reset().
    """
    __slots__ = ("active", "last_accessed", "computed")

    def __init__(self):
        # A new tracker starts inactive and with default per-use state.
        self.active: bool = False
        self.reset()

    def reset(self):
        """Restore the per-use fields to their defaults ("active" is kept)."""
        self.computed: bool = False
        self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME

    def _switch_to(self, state: bool):
        """Store the new "active" flag and wipe the per-use fields."""
        self.active = state
        self.reset()

    def enable(self):
        """Mark the tracker as active; it must currently be inactive."""
        assert not self.active
        self._switch_to(True)

    def disable(self):
        """Mark the tracker as inactive; it must currently be active."""
        assert self.active
        self._switch_to(False)
class PrefixCachingBlockAllocator(BlockAllocator): class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching. """A block allocator that implements prefix caching.
...@@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_ids: Optional[Iterable[int]] = None, block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU, eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
): ):
if block_ids is None:
block_ids = range(num_blocks)
self._block_size = block_size
# A mapping of prefix hash to block index. All blocks which have a # A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0. # prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {} self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of blockId to Block to track those cached blocks # Used to track status of each physical block id
self._blocks: Dict[BlockId, Block] = {} self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids:
self._block_tracker[block_id] = BlockTracker()
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
extra_factor = 4
self._block_pool = BlockPool(self._block_size, self._create_block,
self, num_blocks * extra_factor)
# An allocator for blocks that do not have prefix hashes. # An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator( self._hashless_allocator = NaiveBlockAllocator(
...@@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks=num_blocks, num_blocks=num_blocks,
block_size=block_size, block_size=block_size,
block_ids=block_ids, block_ids=block_ids,
block_pool=self._block_pool, # Share block pool here
) )
self._block_size = block_size
# Evictor used to maintain how we want to handle those computed blocks # Evictor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high. # if we find memory pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy) self.evictor: Evictor = make_evictor(eviction_policy)
...@@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._refcounter = self._hashless_allocator.refcounter self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker( self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(), refcounter=self._refcounter.as_readonly())
allocator=self,
)
# Implements Block.Factory. # Implements Block.Factory.
def _create_block( def _create_block(
...@@ -90,14 +125,14 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -90,14 +125,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
token_ids=token_ids, token_ids=token_ids,
block_size=block_size, block_size=block_size,
block_id=block_id, block_id=block_id,
prefix_caching_allocator=allocator, allocator=allocator,
computed=computed, computed=computed,
) )
def allocate_immutable(self, def allocate_immutable_block(self,
prev_block: Optional[Block], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Optional[Device] = None) -> Block: device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached """Allocates an immutable block with the given token IDs, reusing cached
blocks if possible. blocks if possible.
...@@ -111,29 +146,41 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -111,29 +146,41 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
block = self._create_block( # First, try to create a block that points to cached data
prev_block=prev_block, block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids, token_ids=token_ids,
block_size=self._block_size, block_size=self._block_size,
allocator=self, physical_block_id=None)
)
assert block.content_hash is not None assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None) cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None: if cached_block_id is not None:
block.block_id = cached_block_id block.block_id = cached_block_id
self._incr_refcount_cached_block(block, block.block_id) self._incr_refcount_cached_block(block)
return block return block
self._block_pool.free_block(block)
block = self.allocate_mutable(prev_block) # No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block)
block.append_token_ids(token_ids) block.append_token_ids(token_ids)
assert block.content_hash is not None
return block return block
def allocate_mutable(self, def allocate_immutable_blocks(
prev_block: Optional[Block], self,
device: Optional[Device] = None) -> Block: prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device] = None) -> List[Block]:
blocks = []
for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids,
device=device)
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will """Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks. evict unused cached blocks.
...@@ -147,113 +194,154 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -147,113 +194,154 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert device is None assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
try: block_id = self._allocate_block_id()
block = self._hashless_allocator.allocate_mutable( block = self._block_pool.init_block(prev_block=prev_block,
prev_block=prev_block) token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
assert not block.computed
assert block.content_hash is None
return block
assert block.block_id not in self._blocks def _incr_refcount_cached_block(self, block: Block) -> None:
assert block.block_id is not None # Set this block to be "computed" since it is pointing to a
self._blocks[block.block_id] = block # cached block id (which was already computed)
return block block.computed = True
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
# If the evictor has blocks available for eviction, evict a block block_id = block.block_id
# and return it. assert block_id is not None
if self.evictor.num_blocks > 0:
# here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
_block_id = self._cached_blocks[content_hash_to_evict] refcount = self._refcounter.incr(block_id)
assert self._refcounter.get(_block_id) == 0 if refcount == 1:
assert _block_id == block_id # In case a cached block was evicted, restore its tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._cached_blocks.pop(content_hash_to_evict) self._track_block_id(block_id, computed=True)
self._refcounter.incr(block_id) def _decr_refcount_cached_block(self, block: Block) -> None:
# Ensure this is immutable/cached block
assert block.content_hash is not None
# the block comes from evictor already contain computed result block_id = block.block_id
block = self._create_block( assert block_id is not None
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
computed=True,
)
assert block.content_hash is None
assert block.block_id not in self._blocks refcount = self._refcounter.decr(block_id)
assert block.block_id is not None if refcount > 0:
self._blocks[block.block_id] = block block.block_id = None
return block return
else:
assert refcount == 0
# No block available in hashless allocator, nor in unused cache blocks. # No longer used
raise BlockAllocator.NoFreeBlocksError() assert block.content_hash in self._cached_blocks
def _incr_refcount_cached_block(self, block: Block, # Add the cached block to the evictor
block_id: BlockId) -> None: # (This keeps the cached block around so it can be reused)
# now _incr_refcount_cached_block comes from two place self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
# allocate_immutable/promote_to_immutable_block where hit self._block_tracker[block_id].last_accessed)
# _cached_blocks hash key.
# In both cases, it means that already exists a already
# computed block which shared with block now
block.computed = True
refcount = self._refcounter.incr(block_id) # Stop tracking the block
self._untrack_block_id(block_id)
block.block_id = None
def _decr_refcount_hashless_block(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
# We may have a fork case where block is shared,
# in which case, we cannot remove it from tracking
refcount = self._refcounter.get(block_id)
if refcount == 1: if refcount == 1:
# if block get referred, then it shall not be in evictor self._untrack_block_id(block_id)
# and put it into _blocks for tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._blocks[block_id] = block
def free(self, block: Block) -> None: # Decrement refcount of the block_id, but do not free the block object
"""Decrement the refcount of the block. If the decremented refcount is # itself (will be handled by the caller)
zero, store the block in the freelist. self._hashless_allocator.free(block, keep_block_object=True)
If the block has a content hash (meaning it is immutable), then we will def _allocate_block_id(self) -> BlockId:
keep the block around in case future allocations require it. """First tries to allocate a block id from the hashless allocator,
and if there are no blocks, then tries to evict an unused cached block.
""" """
assert (block.block_id hashless_block_id = self._maybe_allocate_hashless_block_id()
is not None), "freeing unallocated block is undefined" if hashless_block_id is not None:
return hashless_block_id
self._free_block_id_for_block(block.block_id, block) evicted_block_id = self._maybe_allocate_evicted_block_id()
if evicted_block_id is not None:
return evicted_block_id
block.block_id = None # No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]:
try:
# Allocate mutable block and extract its block_id
block = self._hashless_allocator.allocate_mutable_block(
prev_block=None)
block_id = block.block_id
self._block_pool.free_block(block)
self._track_block_id(block_id, computed=False)
return block_id
except BlockAllocator.NoFreeBlocksError:
return None
def _free_block_id_for_block(self, block_id: BlockId, def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
block: Block) -> None: if self.evictor.num_blocks == 0:
assert isinstance(block, PrefixCachingBlock) return None
# if we comes from promote_to_immutable_block, it means that
# block.content_hash is never None.
# However we need to release the same content block, so that
# physical block could get reused.
if block.block_id != block_id or block.content_hash is None:
refcount = self._refcounter.get(block_id)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
assert block.block_id is not None
refcount = self._refcounter.get(block.block_id)
if refcount == 1:
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
refcount = self._refcounter.decr(block_id) # Here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
# Sanity checks
assert content_hash_to_evict in self._cached_blocks
_block_id = self._cached_blocks[content_hash_to_evict]
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
# If no longer used, add the block to the evictor. self._cached_blocks.pop(content_hash_to_evict)
if refcount == 0:
assert block.content_hash in self._cached_blocks self._refcounter.incr(block_id)
assert block.block_id is not None self._track_block_id(block_id, computed=False)
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash, return block_id
block.num_tokens_total, block.last_accessed)
def _free_block_id(self, block: Block) -> None:
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
block_id = block.block_id
assert block_id is not None, "Freeing unallocated block is undefined"
if block.content_hash is not None:
# Immutable: This type of block is always cached, and we want to
# keep it in the evictor for future reuse
self._decr_refcount_cached_block(block)
else:
# Mutable: This type of block is not cached, so we release it
# directly to the hashless allocator
self._decr_refcount_hashless_block(block)
assert block.block_id is None
def free(self, block: Block, keep_block_object: bool = False) -> None:
    """Release a block (see the _free_block_id(..) docs for details).

    Args:
        block (Block): The block to release.
        keep_block_object (bool): When True, only the physical block id is
            released and the block object is not returned to the block
            pool. Defaults to False.
    """
    # The physical block index is always released
    self._free_block_id(block)

    if keep_block_object:
        return

    # Hand the block object back to the pool
    self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying """Creates a new sequence of blocks that shares the same underlying
...@@ -268,20 +356,23 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -268,20 +356,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
""" """
source_blocks = get_all_blocks_recursively(last_block) source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = [] forked_blocks: List[Block] = []
prev_block = None prev_block = None
for block in source_blocks: for block in source_blocks:
refcount = self._refcounter.incr(block.block_id) block_id = block.block_id
assert refcount != 1, "can't fork free'd block" assert block_id is not None
forked_blocks.append( refcount = self._refcounter.incr(block_id)
self._create_block( assert refcount != 1, "can't fork free'd block_id = {}".format(
prev_block=prev_block, block_id)
token_ids=block.token_ids,
block_id=block.block_id, forked_block = self._block_pool.init_block(
block_size=self._block_size, prev_block=prev_block,
allocator=self, token_ids=block.token_ids,
)) block_size=self._block_size,
physical_block_id=block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1] prev_block = forked_blocks[-1]
return forked_blocks return forked_blocks
...@@ -326,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -326,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Note that if we already have a cached block with the same content, we Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached will replace the newly-promoted block's mapping with the existing cached
block. block id.
Args: Args:
block: The mutable block to be promoted. block: The mutable block to be promoted.
...@@ -335,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -335,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator):
BlockId: Either the original block index, or the block index of BlockId: Either the original block index, or the block index of
the previously cached block matching the same content. the previously cached block matching the same content.
""" """
# Ensure block can be promoted
assert block.content_hash is not None assert block.content_hash is not None
assert block.block_id is not None assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0 assert self._refcounter.get(block.block_id) > 0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if block.content_hash not in self._cached_blocks: if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached
# (Note that this block is not computed yet =>
# Will be computed after free())
self._cached_blocks[block.content_hash] = block.block_id self._cached_blocks[block.content_hash] = block.block_id
else: return block.block_id
self._free_block_id_for_block(
self._cached_blocks[block.content_hash], block) # Reuse the cached content hash
self._incr_refcount_cached_block( self._decr_refcount_hashless_block(block)
block, self._cached_blocks[block.content_hash]) block.block_id = self._cached_blocks[block.content_hash]
return self._cached_blocks[block.content_hash] # Increment refcount of the cached block and (possibly) restore
# it from the evictor.
# Note that in this case, the block is marked as computed
self._incr_refcount_cached_block(block)
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: return block.block_id
def cow_block_if_not_appendable(self, block: Block) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not """Performs a copy-on-write operation on the given block if it is not
appendable. appendable.
...@@ -359,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -359,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write. block (Block): The block to check for copy-on-write.
Returns: Returns:
Optional[BlockId]: The block index of the new block if a copy-on BlockId: The block index of the new block if a copy-on-write
-write operation was performed, or the original block index if operation was performed, or the original block index if
no copy-on-write was necessary. no copy-on-write was necessary.
""" """
return self._cow_tracker.cow_block_if_not_appendable(block) src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block)
trg_block_id = self._allocate_block_id()
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it. """Returns the copy-on-write source->destination mapping and clears it.
...@@ -383,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -383,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
""" """
for block_id in block_ids: for block_id in block_ids:
if block_id in self._blocks: if self._block_tracker[block_id].active:
self._blocks[block_id].last_accessed = now self._block_tracker[block_id].last_accessed = now
elif block_id in self.evictor: elif block_id in self.evictor:
self.evictor.update(block_id, now) self.evictor.update(block_id, now)
else: else:
...@@ -392,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -392,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU") "Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching.""" raise NotImplementedError("Marking as computed is incremental")
for block_id in block_ids: def _track_block_id(self, block_id: Optional[BlockId],
if block_id in self._blocks: computed: bool) -> None:
# only those full block is valid for prefix caching assert block_id is not None
if self._blocks[block_id].is_full: self._block_tracker[block_id].enable()
self._blocks[block_id].computed = True self._block_tracker[block_id].computed = computed
elif block_id not in self.evictor:
raise ValueError(f"Mark {block_id=} as computed which " def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
"is not belonged to GPU") assert block_id is not None
self._block_tracker[block_id].disable()
def block_is_computed(self, block_id: int) -> bool: def block_is_computed(self, block_id: int) -> bool:
if block_id in self._blocks: if self._block_tracker[block_id].active:
return self._blocks[block_id].computed return self._block_tracker[block_id].computed
else: else:
return block_id in self.evictor return block_id in self.evictor
def get_computed_block_ids(self,
                           prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool = True) -> List[int]:
    """Extend the known computed block ids with newly computed ones.

    Only positions of block_ids that come after the already known prefix
    (its length is len(prev_computed_block_ids)) are examined. Every
    examined id for which block_is_computed(..) is true gets appended;
    ids that are not computed are skipped and the scan continues.

    Note that prev_computed_block_ids is extended in place and that very
    list object is returned.

    Args:
        prev_computed_block_ids (List[int]): Previously collected computed
            block ids. Modified by this call.
        block_ids (List[int]): Block ids of the sequence.
        skip_last_block_id (bool): When True, the last entry of block_ids
            is left out of the scan. Defaults to True.

    Returns:
        List[int]: prev_computed_block_ids, after being extended.
    """
    known_count = len(prev_computed_block_ids)
    scan_end = len(block_ids) - 1 if skip_last_block_id else len(block_ids)

    # Sanity checks
    assert scan_end >= 0
    assert known_count <= scan_end

    computed_ids = prev_computed_block_ids
    for candidate_id in block_ids[known_count:scan_end]:
        if not self.block_is_computed(candidate_id):
            continue
        computed_ids.append(candidate_id)
    return computed_ids
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]: self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group. """Return the block ids that are common for a given sequence group.
Only those blocks that are immutable and already be marked Only those blocks that are immutable and already be marked
...@@ -421,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -421,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prompt is cached. This would cause erroneous behavior in model # prompt is cached. This would cause erroneous behavior in model
# runner. # runner.
ids_list = [
list(
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
# It returns a list of int although type annotation says list of string. # It returns a list of int although type annotation says list of string.
return commonprefix([ return commonprefix([
ids for ids in ids_list # type: ignore ids for ids in computed_seq_block_ids # type: ignore
if ids != [] if ids != []
]) ])
...@@ -470,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -470,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped out. blocks: List of blocks to be swapped out.
""" """
for block in blocks: for block in blocks:
self.free(block) self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None: def swap_in(self, blocks: List[Block]) -> None:
"""Execute the swap int actions. Change the block id from """Execute the swap in actions. Change the block id from
old allocator to current allocator for each block to finish old allocator to current allocator for each block to finish
the block table update. the block table update.
...@@ -481,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -481,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped in. blocks: List of blocks to be swapped in.
""" """
for block in blocks: for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full: if block.is_full:
alloc = self.allocate_immutable(block.prev_block, tmp_block = self.allocate_immutable_block(
block.token_ids) prev_block=block.prev_block, token_ids=block.token_ids)
else: else:
alloc = self.allocate_mutable(block.prev_block) tmp_block = self.allocate_mutable_block(
alloc.append_token_ids(block.token_ids) prev_block=block.prev_block)
block.block_id = alloc.block_id tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
class PrefixCachingBlock(Block): class PrefixCachingBlock(Block):
...@@ -504,7 +638,7 @@ class PrefixCachingBlock(Block): ...@@ -504,7 +638,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block. token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in block_size (int): The maximum number of token IDs that can be stored in
the block. the block.
prefix_caching_allocator (BlockAllocator): The prefix allocator (BlockAllocator): The prefix
caching block allocator associated with this block. caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index block_id (Optional[int], optional): The physical block index
of this block. Defaults to None. of this block. Defaults to None.
...@@ -515,31 +649,55 @@ class PrefixCachingBlock(Block): ...@@ -515,31 +649,55 @@ class PrefixCachingBlock(Block):
prev_block: Optional[Block], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
block_size: int, block_size: int,
prefix_caching_allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: bool = False, computed: bool = False,
): ):
assert isinstance(prefix_caching_allocator, assert isinstance(allocator, PrefixCachingBlockAllocator), (
PrefixCachingBlockAllocator), ( "Currently this class is only tested with "
"Currently this class is only tested with " "PrefixCachingBlockAllocator. Got instead allocator = {}".format(
"PrefixCachingBlockAllocator.") allocator))
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None self._cached_num_tokens_total: int = 0
self._prefix_caching_allocator = prefix_caching_allocator self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed self._computed = computed
self._block = NaiveBlock( # On the first time, we create the block object, and next we only
prev_block=prev_block, # reinitialize it
token_ids=token_ids, if hasattr(self, "_block"):
block_size=block_size, self._block.__init__( # type: ignore[has-type]
block_id=block_id, prev_block=prev_block,
allocator=prefix_caching_allocator, token_ids=token_ids,
_cow_target=self, block_size=block_size,
) block_id=block_id,
allocator=self._allocator)
else:
self._block = NaiveBlock(prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
self._update_num_tokens_total()
def _update_num_tokens_total(self):
    """Recompute the cached number of tokens up to this block (included).

    The total is the previous block's ``num_tokens_total`` (when a previous
    block exists) plus the number of token ids held by this block. The
    result is stored in ``self._cached_num_tokens_total``.
    """
    # Tokens contributed by all previous blocks in the chain.
    prev_total = 0
    if self._prev_block is not None:
        prev_total = self._prev_block.num_tokens_total

    # Add the current block's tokens and cache the result.
    self._cached_num_tokens_total = prev_total + len(self.token_ids)
@property @property
def computed(self) -> bool: def computed(self) -> bool:
...@@ -561,22 +719,28 @@ class PrefixCachingBlock(Block): ...@@ -561,22 +719,28 @@ class PrefixCachingBlock(Block):
"""Appends the given token IDs to the block and registers the block as """Appends the given token IDs to the block and registers the block as
immutable if the block becomes full. immutable if the block becomes full.
Internally, the naive block handles CoW.
Args: Args:
token_ids (List[int]): The token IDs to be appended to the block. token_ids (List[int]): The token IDs to be appended to the block.
""" """
assert token_ids # Ensure this is mutable block (not promoted)
assert self.content_hash is None
assert not self.computed
if len(token_ids) == 0:
return
# naive block handles CoW. # Ensure there are input tokens
assert token_ids, "Got token_ids = {}".format(token_ids)
# Naive block handles CoW.
self._block.append_token_ids(token_ids) self._block.append_token_ids(token_ids)
self._update_num_tokens_total()
# If the content hash is present, then the block can be made immutable. # If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the # Register ourselves with the allocator, potentially replacing the
# physical block index. # physical block index.
if self.content_hash is not None: if self.content_hash is not None:
self.block_id = (self._prefix_caching_allocator. self.block_id = self._allocator.promote_to_immutable_block(self)
promote_to_immutable_block(self))
@property @property
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
...@@ -596,23 +760,6 @@ class PrefixCachingBlock(Block): ...@@ -596,23 +760,6 @@ class PrefixCachingBlock(Block):
@property @property
def num_tokens_total(self) -> int: def num_tokens_total(self) -> int:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total return self._cached_num_tokens_total
@property @property
...@@ -635,7 +782,6 @@ class PrefixCachingBlock(Block): ...@@ -635,7 +782,6 @@ class PrefixCachingBlock(Block):
For the content-based hash to be defined, the current block must be For the content-based hash to be defined, the current block must be
full. full.
""" """
# If the hash is already computed, return it. # If the hash is already computed, return it.
if self._cached_content_hash is not None: if self._cached_content_hash is not None:
return self._cached_content_hash return self._cached_content_hash
...@@ -685,7 +831,129 @@ class PrefixCachingBlock(Block): ...@@ -685,7 +831,129 @@ class PrefixCachingBlock(Block):
return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
class ComputedBlocksTracker:
    """Handles caching of per-sequence computed block ids.

    When a sequence appears for the first time, it traverses all of the
    blocks and detects the prefix of blocks that is computed. On the
    subsequent times, it only traverses the new blocks that were added
    and updates the already recorded prefix of blocks with the newly
    computed blocks.

    To avoid redundant traversals, the algorithm also detects when there
    is a "gap" in the computed prefix. For example, if we have blocks =
    [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
    we won't try to add more computed blocks to [1,2,3] in this sequence
    iteration, and will add more computed blocks only after the sequence is
    freed and reused again.

    Note that currently, for a given sequence, we also skip the last
    block id for caching purposes, to avoid caching of a full sequence.
    """

    def __init__(self, allocator):
        # Only ``allocator.get_computed_block_ids(...)`` is used by this
        # class.
        self._allocator = allocator
        # seq_id -> (computed block-id prefix, whether a gap was detected)
        self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
                                                          bool]] = {}

    def add_seq(self, seq_id: int) -> None:
        """Start tracking seq_id with an empty computed prefix and no gap."""
        assert seq_id not in self._cached_computed_seq_blocks
        self._cached_computed_seq_blocks[seq_id] = ([], False)

    def remove_seq(self, seq_id: int) -> None:
        """Stop tracking seq_id and drop its cached prefix."""
        assert seq_id in self._cached_computed_seq_blocks
        del self._cached_computed_seq_blocks[seq_id]

    def get_cached_computed_blocks_and_update(
            self, seq_id: int, block_ids: List[int]) -> List[int]:
        """Return the computed block-id prefix of seq_id, extending it if
        new blocks were added since the last call.

        Look at the class documentation for details.

        Args:
            seq_id: Id of a sequence that is already tracked.
            block_ids: Current block ids of the sequence (must be
                non-empty); the last one is never considered.

        Returns:
            The recorded list of computed block ids for the sequence. The
            list object is the one stored in the cache, not a copy.
        """
        # Ensure seq_id is already tracked
        assert seq_id in self._cached_computed_seq_blocks

        # Get cached data (may be empty on the first time)
        prev_computed_block_ids, has_gap = (
            self._cached_computed_seq_blocks[seq_id])

        if has_gap:
            # When gap is detected, we do not add more computed blocks at
            # this sequence iteration
            return prev_computed_block_ids

        # We do not consider the last block id for caching purposes.
        num_cur_blocks = len(block_ids) - 1
        assert num_cur_blocks >= 0

        if len(prev_computed_block_ids) >= num_cur_blocks:
            # Cache HIT
            assert len(prev_computed_block_ids) == num_cur_blocks
            return prev_computed_block_ids

        # If here, then we may possibly add more computed blocks. As a
        # result, traverse the additional blocks after
        # prev_computed_block_ids to detect more computed blocks and add
        # them.

        # Incremental init for seq_id => Look only at the new blocks
        computed_block_ids = self._allocator.get_computed_block_ids(
            prev_computed_block_ids,
            block_ids,
            # We skip last block id to avoid caching of full seq
            skip_last_block_id=True,
        )

        # Detect if there is a "gap"
        has_gap = len(computed_block_ids) < num_cur_blocks

        # Record
        self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
                                                    has_gap)

        return computed_block_ids
class LastAccessBlocksTracker:
    """Manages the last access time of the tracked sequences, in order to
    allow an efficient update of allocator's block last access times.
    """

    def __init__(self, allocator):
        # Only ``allocator.mark_blocks_as_accessed(...)`` is used by this
        # class.
        self._allocator = allocator
        # seq_id -> last recorded access time (None until one is recorded)
        self._seq_last_access: Dict[int, Optional[float]] = {}

    def add_seq(self, seq_id: int) -> None:
        """Start tracking seq_id with no access time recorded yet."""
        assert seq_id not in self._seq_last_access
        self._seq_last_access[seq_id] = None

    def remove_seq(self, seq_id: int) -> None:
        """Stop tracking seq_id."""
        assert seq_id in self._seq_last_access
        del self._seq_last_access[seq_id]

    def update_last_access(self, seq_id: int, time: float) -> None:
        """Record ``time`` as the latest access time of a tracked seq_id."""
        assert seq_id in self._seq_last_access
        self._seq_last_access[seq_id] = time

    def update_seq_blocks_last_access(self, seq_id: int,
                                      block_ids: List[int]) -> None:
        """Forward the recorded access time of seq_id to the allocator.

        Calls ``mark_blocks_as_accessed(block_ids, ts)`` on the allocator
        with the time last recorded for the sequence; does nothing when no
        access time has been recorded.
        """
        assert seq_id in self._seq_last_access

        ts = self._seq_last_access[seq_id]

        if ts is None:
            # No last access was recorded, no need to update.
            return

        self._allocator.mark_blocks_as_accessed(block_ids, ts)
def assert_prefix_caching_block_or_none(block: Optional[Block]): def assert_prefix_caching_block_or_none(block: Optional[Block]):
if block is None: if block is None:
return return
assert isinstance(block, PrefixCachingBlock) assert isinstance(block,
PrefixCachingBlock), "Got block = {}".format(block)
...@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables: Dict[str, BlockTable] = {} self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Sequence) -> int: def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \ return 0 if seq is None else seq.n_blocks
else len(seq.logical_token_blocks)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share # FIXME(woosuk): Here we assume that all sequences in the group share
...@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ref_count: int, \ ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable: is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks) num_prompt_blocks = seq.n_blocks
block_table: BlockTable = [] block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks): for logical_idx in range(num_prompt_blocks):
...@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other # Compute a new hash for the block so that it can be shared by other
# Sequences # Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) new_hash = seq.hash_of_block(seq.n_blocks - 1)
# if new_hash is already in the cached table, then free last_block # if new_hash is already in the cached table, then free last_block
# and return the cached version # and return the cached version
...@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if not self.enable_caching: if not self.enable_caching:
return self.gpu_allocator.allocate() return self.gpu_allocator.allocate()
block_hash: Optional[int] = None block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)): if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block( num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
len(seq.logical_token_blocks) - 1)
# num_hashed_tokens is used to compute future hashes # num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for # (e.g. in the hashing function, it is used to ask the sequence for
...@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]: ) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token.""" """Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block # If we need to allocate a new physical block
if len(block_table) < len(logical_blocks): if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block # Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1 assert len(block_table) == n_blocks - 1
if (self.block_sliding_window if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window): and len(block_table) >= self.block_sliding_window):
...@@ -472,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -472,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block. # NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM. # Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy() self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused. # When using a sliding window, blocks will be eventually reused.
......
...@@ -7,6 +7,8 @@ from typing import Tuple ...@@ -7,6 +7,8 @@ from typing import Tuple
from vllm.core.block.block_table import BlockTable from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
...@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self.block_tables: Dict[SeqId, BlockTable] = {} self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share # FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences. # the same prompt. This may not be true for preempted sequences.
...@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table: BlockTable = self._allocate_sequence(seq) block_table: BlockTable = self._allocate_sequence(seq)
self.block_tables[seq.seq_id] = block_table self.block_tables[seq.seq_id] = block_table
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence. # Assign the block table for each sequence.
for seq in waiting_seqs[1:]: for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork() self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence # Allocate cross-attention block table for encoder sequence
# #
# NOTE: Here we assume that all sequences in the group have the same # NOTE: Here we assume that all sequences in the group have the same
...@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return new_cows return new_cows
def free(self, seq: Sequence) -> None: def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables: seq_id = seq.seq_id
if seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet. # Already freed or haven't been scheduled yet.
return return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id] # Update seq block ids with the latest access time
self._last_access_blocks_tracker.update_seq_blocks_last_access(
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
# Untrack seq
self._last_access_blocks_tracker.remove_seq(seq_id)
self._computed_blocks_tracker.remove_seq(seq_id)
# Free table/blocks
self.block_tables[seq_id].free()
del self.block_tables[seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None: def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id request_id = seq_group.request_id
...@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
del self.cross_block_tables[request_id] del self.cross_block_tables[request_id]
def get_block_table(self, seq: Sequence) -> List[int]: def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore return block_ids # type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
...@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return block_ids # type: ignore return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float): def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
# in this step.
# And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if self.enable_caching: if self.enable_caching:
block_table = self.block_tables[seq.seq_id] # Record the latest access time for the sequence. The actual update
block_ids = [] # of the block ids is deferred to the sequence free(..) call, since
for block_id in block_table.physical_block_ids: # only during freeing of block ids, the blocks are actually added to
block_ids.append(block_id) # the evictor (which is when the most updated time is required)
self.block_allocator.mark_blocks_as_accessed( # (This avoids expensive calls to mark_blocks_as_accessed(..))
block_ids, # type: ignore self._last_access_blocks_tracker.update_last_access(
now) seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching, # The only need for mark block as computed is for prefix caching,
...@@ -285,17 +304,29 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -285,17 +304,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
This method determines which blocks can be safely skipped for all This method determines which blocks can be safely skipped for all
sequences in the sequence group. sequences in the sequence group.
""" """
seq_block_ids = [ computed_seq_block_ids = []
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs for seq in seqs:
] computed_seq_block_ids.append(
self._computed_blocks_tracker.
get_cached_computed_blocks_and_update(
seq.seq_id,
self.block_tables[seq.seq_id].physical_block_ids))
# NOTE(sang): This assumes seq_block_ids doesn't contain any None. # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids( return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) # type: ignore computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork() self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup, def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus: num_lookahead_slots: int) -> AllocStatus:
"""Returns the AllocStatus for the given sequence_group """Returns the AllocStatus for the given sequence_group
...@@ -323,19 +354,31 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -323,19 +354,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from CPU List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU. to GPU.
""" """
blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED) physical_block_id_mapping = []
current_swap_mapping = self.block_allocator.swap( for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU) blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
block_number_mapping = { continue
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id): seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
self.block_allocator.get_physical_block_id(Device.GPU, src_device=Device.CPU,
gpu_block_id) dst_device=Device.GPU)
for cpu_block_id, gpu_block_id in current_swap_mapping.items()
} # Refresh the block ids of the table (post-swap)
# convert to list of tuples once here self.block_tables[seq.seq_id].update(blocks)
return list(block_number_mapping.items())
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id):
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id)
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool: def can_swap_out(self, seq_group: SequenceGroup) -> bool:
"""Returns whether we can swap out the given sequence_group """Returns whether we can swap out the given sequence_group
...@@ -355,7 +398,7 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -355,7 +398,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return True return True
return False return False
def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]: def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by """Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots. swapping out the given sequence_group with num_lookahead_slots.
...@@ -366,19 +409,31 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -366,19 +409,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU. GPU to CPU.
""" """
blocks = self._get_blocks_for_swap(sequence_group, physical_block_id_mapping = []
SequenceStatus.RUNNING) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
current_swap_mapping = self.block_allocator.swap( blocks = self.block_tables[seq.seq_id].blocks
blocks=blocks, source_device=Device.GPU, dest_device=Device.CPU) if len(blocks) == 0:
block_number_mapping = { continue
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id): seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
self.block_allocator.get_physical_block_id(Device.CPU, src_device=Device.GPU,
cpu_block_id) dst_device=Device.CPU)
for gpu_block_id, cpu_block_id in current_swap_mapping.items()
} # Refresh the block ids of the table (post-swap)
# convert to list of tuples once here self.block_tables[seq.seq_id].update(blocks)
return list(block_number_mapping.items())
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id):
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id)
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def get_num_free_gpu_blocks(self) -> int: def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU) return self.block_allocator.get_num_free_blocks(Device.GPU)
......
...@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager ...@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
...@@ -50,8 +51,8 @@ class SchedulingBudget: ...@@ -50,8 +51,8 @@ class SchedulingBudget:
""" """
token_budget: int token_budget: int
max_num_seqs: int max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0 _num_batched_tokens: int = 0
_num_curr_seqs: int = 0 _num_curr_seqs: int = 0
...@@ -65,28 +66,28 @@ class SchedulingBudget: ...@@ -65,28 +66,28 @@ class SchedulingBudget:
return self.token_budget - self.num_batched_tokens return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens: if req_id in self._request_ids_num_batched_tokens:
return return
self._requeset_ids_num_batched_tokens.add(req_id) self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens self._num_batched_tokens += num_batched_tokens
def subtract_num_batched_tokens(self, req_id: str, def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int): num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens: if req_id in self._request_ids_num_batched_tokens:
self._requeset_ids_num_batched_tokens.remove(req_id) self._request_ids_num_batched_tokens.remove(req_id)
self._num_batched_tokens -= num_batched_tokens self._num_batched_tokens -= num_batched_tokens
def add_num_seqs(self, req_id: str, num_curr_seqs: int): def add_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs: if req_id in self._request_ids_num_curr_seqs:
return return
self._requeset_ids_num_curr_seqs.add(req_id) self._request_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs self._num_curr_seqs += num_curr_seqs
def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs: if req_id in self._request_ids_num_curr_seqs:
self._requeset_ids_num_curr_seqs.remove(req_id) self._request_ids_num_curr_seqs.remove(req_id)
self._num_curr_seqs -= num_curr_seqs self._num_curr_seqs -= num_curr_seqs
@property @property
...@@ -139,6 +140,8 @@ class SchedulerOutputs: ...@@ -139,6 +140,8 @@ class SchedulerOutputs:
if self.num_loras > 0: if self.num_loras > 0:
self._sort_by_lora_ids() self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
...@@ -157,6 +160,14 @@ class SchedulerOutputs: ...@@ -157,6 +160,14 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None if g.seq_group.lora_request is not None
} }
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
    """Return the distinct prompt adapter requests of the scheduled groups.

    Collects ``seq_group.prompt_adapter_request`` from every entry of
    ``self.scheduled_seq_groups``, skipping entries where it is None.
    """
    return {
        g.seq_group.prompt_adapter_request
        for g in self.scheduled_seq_groups
        if g.seq_group.prompt_adapter_request is not None
    }
@dataclass @dataclass
class SchedulerRunningOutputs: class SchedulerRunningOutputs:
...@@ -256,6 +267,7 @@ class Scheduler: ...@@ -256,6 +267,7 @@ class Scheduler:
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -273,11 +285,19 @@ class Scheduler: ...@@ -273,11 +285,19 @@ class Scheduler:
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version) version)
num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size
num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManagerImpl( self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window, sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching) enable_caching=self.cache_config.enable_prefix_caching)
...@@ -290,7 +310,10 @@ class Scheduler: ...@@ -290,7 +310,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out. # Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step # Time at previous scheduling step
self.prev_time = 0.0 self.prev_time = 0.0
# Did we schedule a prompt at previous step? # Did we schedule a prompt at previous step?
...@@ -364,6 +387,12 @@ class Scheduler: ...@@ -364,6 +387,12 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
    """Flushes the list of request ids of previously finished seq_groups.

    Returns:
        The list accumulated so far; ``self._finished_requests_ids`` is
        replaced by a fresh empty list.
    """
    finished_requests_ids = self._finished_requests_ids
    # Rebind rather than clear, so the returned list keeps its contents.
    self._finished_requests_ids = []
    return finished_requests_ids
def _schedule_running( def _schedule_running(
self, self,
running_queue: deque, running_queue: deque,
...@@ -1006,6 +1035,7 @@ class Scheduler: ...@@ -1006,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None. # `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
...@@ -1027,6 +1057,11 @@ class Scheduler: ...@@ -1027,6 +1057,11 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished()) if not seq_group.is_finished())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment