Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
......@@ -3,14 +3,14 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__
__version__ = "0.4.1"
__version__ = "0.4.2"
__all__ = [
"LLM",
......
......@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)
def paged_attention_v2(
......@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)
......@@ -153,11 +153,49 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
else:
vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output, scale
......@@ -184,6 +222,18 @@ def reshape_and_cache(
slot_mapping, kv_cache_dtype, kv_scale)
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
......
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
import torch
......@@ -15,7 +16,7 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
raise NotImplementedError
@staticmethod
......@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]:
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
for field in fields(self) if field.name not in skip_fields
}
......
......@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
......@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
......@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......@@ -245,10 +245,11 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
......@@ -257,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
try:
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
use_cuda_graph: bool = False
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The number of query/output heads
num_qo_heads: Optional[int] = None
# The number of key/value heads
num_kv_heads: Optional[int] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
# Block size of vllm
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if not self.is_prompt:
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
self.paged_kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
)
if prefill_meta := attn_metadata.prefill_metadata:
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
return output.view(num_tokens, hidden_size)
"""Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
......@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
......@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
class ROCmFlashAttentionImpl(AttentionImpl):
......@@ -156,8 +156,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = (os.environ.get(
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
......@@ -248,41 +247,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn or self.use_naive_attn:
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
True,
self.scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
out = self.attn_func(
query,
key,
value,
prefill_meta.prompt_lens,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
True,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
out = self.attn_func(
query,
key,
value,
prefill_meta.seq_lens,
self.scale,
)
else:
out = self.attn_func(
q=query,
......@@ -290,13 +284,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
......@@ -307,10 +303,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
if decode_meta := attn_metadata.decode_metadata:
......@@ -320,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......@@ -337,13 +334,13 @@ def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
prompt_lens: List[int],
seq_lens: List[int],
scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len
for _, seq_len in enumerate(seq_lens):
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
......@@ -352,7 +349,7 @@ def _naive_attention(
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
start += seq_len
return output
......
......@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
seq_lens: Optional[List[int]]
def __post_init__(self):
# Set during the execution of the first attention op.
......@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
......@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window,
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prompt_lens)
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
......@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens,
attn_metadata.attn_bias):
end = start + prompt_len
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
......@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
......@@ -221,7 +221,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
......@@ -229,14 +229,14 @@ def _make_alibi_bias(
def _make_sliding_window_bias(
prompt_lens: List[int],
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
for seq_len in seq_lens:
tensor = torch.full(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
......
......@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
......@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
......@@ -242,10 +241,11 @@ class XFormersImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
......@@ -256,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......@@ -288,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
......@@ -309,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.prompt_lens)
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
......@@ -317,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.prompt_lens)
attn_metadata.seq_lens)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
......@@ -342,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query)
start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
......@@ -353,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len
start += seq_len
return output
......@@ -361,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
......@@ -375,16 +375,16 @@ def _make_alibi_bias(
# element.
bias = bias[None, :] - bias[:, None]
padded_len = (prompt_len + 7) // 8 * 8
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
prompt_len,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
......
......@@ -47,3 +47,10 @@ class Attention(nn.Module):
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
return s
......@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# Maximum context length in the batch.
max_context_len: Optional[int]
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
......@@ -85,8 +84,8 @@ class PagedAttention:
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
......@@ -97,7 +96,7 @@ class PagedAttention:
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
......@@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_context_len <= 8192
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
......@@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
......@@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
......@@ -168,10 +167,11 @@ class PagedAttention:
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
......@@ -184,10 +184,11 @@ class PagedAttention:
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
prompt_lens_tensor,
seq_lens_tensor,
context_lens,
max_subquery_len,
max_query_len,
alibi_slopes,
sliding_window,
)
return output
......
......@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]
q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0)
other=0.0) # [M,D]
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M]
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M]
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
other=0) # [N]
# [D,N]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
# [N,D]
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
......@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
k = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0)
other=0.0) # [D,N]
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
m_ij = tl.max(qk, 1) # [M]
p = tl.exp(qk - m_ij[:, None]) # [M,N]
l_ij = tl.sum(p, 1) # [M]
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
m_i_new = tl.maximum(m_i, m_ij) # [M]
alpha = tl.exp(m_i - m_i_new) # [M]
beta = tl.exp(m_ij - m_i_new) # [M]
l_i_new = alpha * l_i + beta * l_ij # [M]
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
......@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
v = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
other=0.0) # [N,D]
p = p.to(v.dtype)
acc += tl.dot(p, v)
......@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
......@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
......@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None):
alibi_slopes=None,
sliding_window=None):
cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
......@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = 2**((Lk - 1).bit_length())
Lk_padded = triton.next_power_of_2(Lk)
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
......@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
num_warps=num_warps,
num_stages=1,
)
......
......@@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps=4,
),
],
key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
......@@ -330,8 +330,8 @@ def attn_fwd(
philox_seed,
philox_offset_base,
encoded_softmax,
hq,
hk,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
......@@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
......@@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
return
is_mqa = hq != hk
if is_mqa: # noqa: SIM108
off_h_k = off_h_q % hk
else:
off_h_k = off_h_q
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
......@@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * hq + off_h_q) \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
......@@ -624,7 +622,7 @@ def attn_fwd(
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
......@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
hq=nheads_q,
hk=nheads_k,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
......
import enum
import os
from functools import lru_cache
from typing import Type
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip
logger = init_logger(__name__)
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
if backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention backend.")
logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
......@@ -43,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")
......@@ -62,12 +66,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention backend for Volta and Turing "
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention backend for dtype other than "
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
......@@ -75,11 +79,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
import flash_attn # noqa: F401
except ImportError:
logger.info(
"Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance.")
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
return _Backend.XFORMERS
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
......
import enum
import json
import os
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
......@@ -9,11 +8,14 @@ from packaging.version import Version
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
......@@ -21,10 +23,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
"False").lower() == "true"
_GB = 1 << 30
......@@ -33,6 +31,8 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
......@@ -65,9 +65,16 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
def __init__(
......@@ -86,8 +93,10 @@ class ModelConfig:
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
......@@ -101,6 +110,11 @@ class ModelConfig:
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init
......@@ -110,6 +124,8 @@ class ModelConfig:
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.served_model_name = get_served_model_name(model,
served_model_name)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_quantization()
......@@ -138,14 +154,34 @@ class ModelConfig:
is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
or quant_cfg.get("is_marlin_format", False))
# Use marlin if the GPTQ model is serialized in marlin format.
if quant_method == "gptq" and is_format_marlin:
logger.info("The model is serialized in Marlin format. "
# Check which LinearMethod the GPTQ model should use.
if quant_method == "gptq":
# If serialized in Marlin format, use MarlinLinearMethod.
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
if is_format_marlin:
logger.info("The model is serialized in Marlin format. "
"Using Marlin kernel.")
quant_method = "marlin"
if self.quantization == "gptq":
self.quantization = quant_method
# If convertible to Marlin format, use GPTQMarlinLinearMethod
# unless the user explicitly specified GPTQLinearMethod.
elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
if self.quantization == "gptq":
logger.warning(
"The model is convertible to Marlin format, but "
"you specified quantization=gptq. Use "
"quantization=marlin for faster inference.")
else:
logger.info(
"The model is convertible to Marlin format. "
"Using Marlin kernel.")
quant_method = "marlin"
if self.quantization == "gptq":
self.quantization = quant_method
quant_method = "gptq_marlin"
if self.quantization == "marlin":
self.quantization = quant_method
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
......@@ -165,17 +201,17 @@ class ModelConfig:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if self.quantization != "marlin":
if (self.quantization not in ["marlin", "gptq_marlin"]):
logger.warning(
f"{self.quantization} quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
self.max_model_len)
if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
def verify_with_parallel_config(
self,
......@@ -271,6 +307,11 @@ class ModelConfig:
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \
parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
......@@ -330,7 +371,8 @@ class CacheConfig:
elif self.cache_dtype == "fp8":
if not is_hip():
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version < Version("11.8"):
if nvcc_cuda_version is not None \
and nvcc_cuda_version < Version("11.8"):
raise ValueError(
"FP8 is not supported when cuda version is"
"lower than 11.8.")
......@@ -360,7 +402,7 @@ class CacheConfig:
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. " + msg)
logger.warning("Possibly too large swap space. %s", msg)
@dataclass
......@@ -574,8 +616,9 @@ class SchedulerConfig:
self.max_num_batched_tokens = max_num_batched_tokens
else:
if enable_chunked_prefill:
# For chunked prefill, choose the well-tuned batch size.
self.max_num_batched_tokens = 768
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
self.max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
......@@ -658,6 +701,8 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
......@@ -684,6 +729,10 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
......@@ -718,39 +767,57 @@ class SpeculativeConfig:
draft_code_revision = None
draft_quantization = None
draft_model_config = ModelConfig(
model=speculative_model,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
max_context_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
if speculative_model == "[ngram]":
assert (ngram_prompt_lookup_max is not None
and ngram_prompt_lookup_max > 0)
if ngram_prompt_lookup_min is None:
ngram_prompt_lookup_min = 0
else:
assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
# TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set
# draft related config as None here.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
)
@staticmethod
......@@ -818,6 +885,8 @@ class SpeculativeConfig:
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
ngram_prompt_lookup_max: int,
ngram_prompt_lookup_min: int,
):
"""Create a SpeculativeConfig object.
......@@ -830,6 +899,8 @@ class SpeculativeConfig:
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
self._verify_args()
......@@ -853,7 +924,10 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def __repr__(self) -> str:
draft_model = self.draft_model_config.model
if self.ngram_prompt_lookup_max > 0:
draft_model = "[ngram]"
else:
draft_model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
......@@ -862,6 +936,7 @@ class SpeculativeConfig:
class LoRAConfig:
max_lora_rank: int
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
......@@ -898,8 +973,8 @@ class LoRAConfig:
"awq", "gptq"
]:
# TODO support marlin and squeezellm
logger.warning(f"{model_config.quantization} quantization is not "
"tested with LoRA yet.")
logger.warning("%s quantization is not tested with LoRA yet.",
model_config.quantization)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
......@@ -1008,7 +1083,7 @@ def _get_and_verify_dtype(
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
return torch_dtype
......@@ -1051,12 +1126,12 @@ def _get_and_verify_max_len(
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
"%d. Assuming the model's maximum length is %d.", possible_keys,
default_max_len)
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if rope_scaling is not None and rope_scaling["type"] != "su":
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
......@@ -1084,6 +1159,22 @@ def _get_and_verify_max_len(
return int(max_model_len)
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
empty list, the fallback is to use `self.model`.
"""
if not served_model_name:
return model
if isinstance(served_model_name, list):
return served_model_name[0]
return served_model_name
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""
......
......@@ -40,7 +40,9 @@ class BlockTable:
):
self._block_size = block_size
self._allocator = block_allocator
self._blocks: Optional[List[Block]] = _blocks
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
......@@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None
assert len(self._blocks) > 0
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
......@@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
......@@ -159,6 +162,7 @@ class BlockTable:
the current instance.
"""
assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
......@@ -177,10 +181,10 @@ class BlockTable:
assert self._is_allocated
for block in self._blocks:
self._allocator.free(block)
self._blocks = None
self._blocks = []
@property
def physical_block_ids(self) -> List[int]:
def physical_block_ids(self) -> List[Optional[int]]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
......@@ -235,7 +239,7 @@ class BlockTable:
def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids = []
token_ids: List[int] = []
if not self._is_allocated:
return token_ids
......@@ -247,7 +251,7 @@ class BlockTable:
@property
def _is_allocated(self) -> bool:
return self._blocks is not None
return len(self._blocks) > 0
@property
def _num_empty_slots(self) -> int:
......
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Protocol
from vllm.core.block.interfaces import Block, BlockAllocator
......@@ -7,7 +7,19 @@ BlockId = int
RefCount = int
class RefCounter:
class RefCounterProtocol(Protocol):
def incr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def decr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def get(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
class RefCounter(RefCounterProtocol):
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
......@@ -54,7 +66,7 @@ class RefCounter:
return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
class ReadOnlyRefCounter(RefCounterProtocol):
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
......@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def __init__(
self,
refcounter: RefCounter,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
......
from typing import Dict, List, Optional
from typing import Dict, FrozenSet, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator,
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
......@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
gpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
......@@ -105,7 +105,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.GPU: gpu_block_allocator,
}
self._block_ids_to_allocator = {}
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator
......@@ -149,7 +149,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
block (Block): The block to be freed.
"""
allocator = self._block_ids_to_allocator[block.block_id]
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
......@@ -163,7 +165,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
allocator = self._block_ids_to_allocator[last_block.block_id]
block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int:
......@@ -171,13 +175,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
device (Device): The device for which to query the number of free
blocks.
blocks. AssertionError is raised if None is passed.
Returns:
int: The number of free blocks available on the specified device.
"""
return self._allocators[device].get_num_free_blocks()
def get_num_total_blocks(self, device: Device) -> int:
return self._allocators[device].get_num_total_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
......@@ -190,10 +197,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device = Device.GPU
return self._allocators[device].clear_copy_on_writes()
def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed()
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
......@@ -202,5 +217,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
def all_block_ids(self) -> frozenset[int]:
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
......@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
BlockId = int
class Block(ABC):
......@@ -15,6 +17,12 @@ class Block(ABC):
def block_id(self) -> Optional[int]:
pass
@block_id.setter
@abstractmethod
def block_id(self, value: Optional[int]) -> None:
"""NOTE: Do not use this API outside Block."""
self._block_id = value
@property
@abstractmethod
def token_ids(self) -> List[int]:
......@@ -35,6 +43,27 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]:
pass
@property
@abstractmethod
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
@abstractmethod
def computed(self, value) -> bool:
"""Should be only used by PrefixCacingAllocator"""
raise NotImplementedError
@property
@abstractmethod
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
@abstractmethod
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
class Factory(Protocol):
@abstractmethod
......@@ -48,6 +77,17 @@ class Block(ABC):
) -> "Block":
pass
@property
@abstractmethod
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return None
class BlockAllocator(ABC):
......@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
token_ids: List[int]) -> Block:
pass
@abstractmethod
......@@ -69,7 +109,11 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
def get_num_total_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@property
......@@ -82,7 +126,12 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
......@@ -90,14 +139,25 @@ class BlockAllocator(ABC):
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
"""NOTE: This should not be used besides Block"""
pass
@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
pass
class NoFreeBlocksError(ValueError):
pass
class DeviceAwareBlockAllocator(BlockAllocator):
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
......@@ -108,3 +168,38 @@ class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
pass
@abstractmethod
def get_num_total_blocks(self, device: Device) -> int:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
from typing import Dict, Iterable, List, Optional, Set
from typing import Dict, FrozenSet, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
BlockId = int
Refcount = int
......@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator=self,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
......@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
......@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
assert device is None
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
......@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
)
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
# Mark the block as having no allocation.
......@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for block in source_blocks:
# Increment refcount for each block.
assert block.block_id is not None
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
......@@ -129,6 +136,9 @@ class NaiveBlockAllocator(BlockAllocator):
def get_num_free_blocks(self) -> int:
return len(self._free_block_indices)
def get_num_total_blocks(self) -> int:
return len(self._all_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
......@@ -148,7 +158,7 @@ class NaiveBlockAllocator(BlockAllocator):
return self._refcounter
@property
def all_block_ids(self):
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
......@@ -174,7 +184,16 @@ class NaiveBlockAllocator(BlockAllocator):
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
......@@ -191,6 +210,9 @@ class NaiveBlockAllocator(BlockAllocator):
"""
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
......@@ -215,13 +237,13 @@ class NaiveBlock(Block):
"""
def __init__(self,
prev_block: Block,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
_cow_target: Optional[Block] = None):
self._token_ids = []
self._token_ids: List[int] = []
self._block_size = block_size
self._prev_block = prev_block
self._block_id = block_id
......@@ -247,6 +269,22 @@ class NaiveBlock(Block):
assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids)
@property
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
def computed(self, value) -> None:
raise NotImplementedError
@property
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
@property
def block_id(self) -> Optional[int]:
return self._block_id
......@@ -267,9 +305,14 @@ class NaiveBlock(Block):
def token_ids(self) -> List[int]:
return self._token_ids
@property
def block_size(self) -> int:
return self._block_size
@property
def prev_block(self) -> Optional["Block"]:
return self._prev_block
@property
def content_hash(self) -> Optional[int]:
return None
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, Iterable, List, Optional
from typing import Dict, FrozenSet, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
PrefixHash = int
BlockId = int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
class PrefixCachingBlockAllocator(BlockAllocator):
......@@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1.
"""
# TODO last access time / evictor integration
def __init__(
self,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
# of self._cached_blocks.
self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of blockId to Block to track those cached blocks
self._blocks: Dict[BlockId, Block] = {}
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block,
create_block=self._create_block, # type: ignore
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
......@@ -54,6 +56,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._block_size = block_size
# Evitor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy)
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks.
......@@ -72,6 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
) -> Block:
# Bind block to self.
allocator = self
......@@ -82,10 +89,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size=block_size,
block_id=block_id,
prefix_caching_allocator=allocator,
computed=computed,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
......@@ -96,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns:
Block: The allocated immutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
......@@ -109,65 +120,95 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
block.block_id = cached_block_id
self._incr_refcount_cached_block(block.content_hash,
block.block_id)
self._incr_refcount_cached_block(block, block.block_id)
return block
block = self.allocate_mutable(prev_block)
block.append_token_ids(token_ids)
assert block.content_hash is not None
# TODO computed bit
return block
def allocate_mutable(self, prev_block: Block) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
None is not allowed unlike it is super class.
Returns:
Block: The allocated mutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
try:
return self._hashless_allocator.allocate_mutable(
block = self._hashless_allocator.allocate_mutable(
prev_block=prev_block)
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
if self._unused_cached_blocks:
# TODO policy for selecting block to remove
content_hash_to_evict = next(iter(self._unused_cached_blocks))
# If the evictor has blocks available for eviction, evict a block
# and return it.
if self.evictor.num_blocks > 0:
block_id, content_hash_to_evict = self.evictor.evict()
# Here we may have scenario that several blocks have
# the same content hash, but due to the latter coming block
# is coming from mutable to immutable path, their physical
# block is added into evictor.
# However in this case, we shall not pop the _cached_blocks,
# as the same content is still used by others, which means
# we need to check ref before decide to pop the list.
# Clear content hash mapping; the block will be overwritten.
del self._cached_blocks[content_hash_to_evict]
_block_id = self._cached_blocks[content_hash_to_evict]
refcount = self._refcounter.get(_block_id)
if refcount == 1:
self._cached_blocks.pop(content_hash_to_evict)
assert _block_id == block_id
block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
refcount = self._refcounter.incr(block_id)
assert refcount == 1
self._refcounter.incr(block_id)
# the block comes from evictor already contain computed result
block = self._create_block(
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
computed=True,
)
assert block.content_hash is None
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _incr_refcount_cached_block(self, content_hash: int,
def _incr_refcount_cached_block(self, block: Block,
block_id: BlockId) -> None:
# since block is already computed, mark it
block.computed = True
refcount = self._refcounter.incr(block_id)
if refcount == 1:
assert content_hash in self._unused_cached_blocks
del self._unused_cached_blocks[content_hash]
# if block get referred, then it shall not be in evictor
# and put it into _blocks for tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._blocks[block_id] = block
def free(self, block: Block) -> None:
"""Decrement the refcount of the block. If the decremented refcount is
......@@ -180,6 +221,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
is not None), "freeing unallocated block is undefined"
self._free_block_id_for_block(block.block_id, block)
block.block_id = None
def _free_block_id_for_block(self, block_id: BlockId,
......@@ -187,15 +229,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert isinstance(block, PrefixCachingBlock)
if block.content_hash is None:
refcount = self._refcounter.get(block_id)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
if refcount <= 1:
assert block.block_id is not None
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
refcount = self._refcounter.decr(block_id)
# If no longer used, add the block to the unused cached blocks.
# If no longer used, add the block to the evictor.
if refcount == 0:
assert block.content_hash not in self._unused_cached_blocks
assert block.content_hash in self._cached_blocks
self._unused_cached_blocks[block.content_hash] = block_id
assert block.block_id is not None
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
......@@ -228,18 +278,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return forked_blocks
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks
# plus the number of hashful blocks that are unused.
return self._hashless_allocator.get_num_free_blocks() + len(
self._unused_cached_blocks)
# plus the number of blocks evictor could free from its list.
return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks
def get_num_total_blocks(self) -> int:
return self._hashless_allocator.get_num_total_blocks()
@property
def all_block_ids(self) -> frozenset[int]:
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self,
block: "PrefixCachingBlock") -> BlockId:
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
......@@ -249,7 +302,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.
Args:
block (PrefixCachingBlock): The mutable block to be promoted.
block: The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
......@@ -266,7 +319,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else:
self._free_block_id_for_block(block.block_id, block)
self._incr_refcount_cached_block(
block.content_hash, self._cached_blocks[block.content_hash])
block, self._cached_blocks[block.content_hash])
return self._cached_blocks[block.content_hash]
......@@ -293,29 +346,63 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
If the block is added into evictor, we need to update corresponding
info in evictor's metadata.
"""
for block_id in block_ids:
if block_id in self._blocks:
self._blocks[block_id].last_accessed = now
elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
raise ValueError(
"Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching."""
# TODO Track computed blocks.
pass
for block_id in block_ids:
if block_id in self._blocks:
# only those full block is valid for prefix caching
if self._blocks[block_id].is_full:
self._blocks[block_id].computed = True
elif block_id not in self.evictor:
raise ValueError(f"Mark {block_id=} as computed which "
"is not belonged to GPU")
def block_is_computed(self, block_id: int) -> bool:
if block_id in self._blocks:
return self._blocks[block_id].computed
else:
return block_id in self.evictor
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
Only those blocks that are immutable and already be marked
compyted would be taken consideration.
"""
# TODO: Track computed blocks.
computed = lambda block_id: False
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# runner.
ids_list = [
takewhile(lambda block_id: computed(block_id), seq[:-1])
for seq in seq_block_ids
list(
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
return commonprefix([ids for ids in ids_list if ids != []])
# It returns a list of int although type annotation says list of string.
return commonprefix([
ids for ids in ids_list # type: ignore
if ids != []
])
class PrefixCachingBlock(Block):
......@@ -332,7 +419,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
......@@ -340,17 +427,25 @@ class PrefixCachingBlock(Block):
def __init__(
self,
prev_block: Optional["PrefixCachingBlock"],
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: PrefixCachingBlockAllocator,
prefix_caching_allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._block = NaiveBlock(
prev_block=prev_block,
......@@ -361,6 +456,22 @@ class PrefixCachingBlock(Block):
_cow_target=self,
)
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
......@@ -398,6 +509,27 @@ class PrefixCachingBlock(Block):
def num_empty_slots(self) -> int:
return self._block.num_empty_slots
@property
def num_tokens_total(self) -> int:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total
@property
def block_size(self) -> int:
return self._block.block_size
......@@ -428,8 +560,10 @@ class PrefixCachingBlock(Block):
return None
is_first_block = self._prev_block is None
prev_block_hash = (None if is_first_block else
self._prev_block.content_hash)
prev_block_hash = (
None if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment