Unverified commit cf8cac8c, authored by SangBin Cho, committed by GitHub
Browse files

[mypy][6/N] Fix all the core subdirectory typing (#4450)


Co-authored-by: Cade Daniel <edacih@gmail.com>
parent 5e401bce
......@@ -33,6 +33,7 @@ jobs:
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
......@@ -42,9 +43,6 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
......@@ -95,7 +95,7 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
......
......@@ -40,7 +40,9 @@ class BlockTable:
):
self._block_size = block_size
self._allocator = block_allocator
self._blocks: Optional[List[Block]] = _blocks
if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
......@@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None
assert len(self._blocks) > 0
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
......@@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
......@@ -159,6 +162,7 @@ class BlockTable:
the current instance.
"""
assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
......@@ -177,10 +181,10 @@ class BlockTable:
assert self._is_allocated
for block in self._blocks:
self._allocator.free(block)
self._blocks = None
self._blocks = []
@property
def physical_block_ids(self) -> List[int]:
def physical_block_ids(self) -> List[Optional[int]]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
......@@ -235,7 +239,7 @@ class BlockTable:
def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids = []
token_ids: List[int] = []
if not self._is_allocated:
return token_ids
......@@ -247,7 +251,7 @@ class BlockTable:
@property
def _is_allocated(self) -> bool:
return self._blocks is not None
return len(self._blocks) > 0
@property
def _num_empty_slots(self) -> int:
......
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Protocol
from vllm.core.block.interfaces import Block, BlockAllocator
......@@ -7,7 +7,19 @@ BlockId = int
RefCount = int
class RefCounter:
class RefCounterProtocol(Protocol):
def incr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def decr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def get(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
class RefCounter(RefCounterProtocol):
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
......@@ -54,7 +66,7 @@ class RefCounter:
return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
class ReadOnlyRefCounter(RefCounterProtocol):
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
......@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def __init__(
self,
refcounter: RefCounter,
refcounter: RefCounterProtocol,
allocator: BlockAllocator,
):
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
......
from typing import Dict, List, Optional
from typing import Dict, FrozenSet, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator,
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
......@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
gpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
......@@ -105,13 +105,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.GPU: gpu_block_allocator,
}
self._block_ids_to_allocator = {}
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
......@@ -122,10 +123,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
assert device is not None
return self._allocators[device].allocate_mutable(prev_block)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
......@@ -140,6 +144,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided
token IDs.
"""
assert device is not None
return self._allocators[device].allocate_immutable(
prev_block, token_ids)
......@@ -149,7 +154,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
block (Block): The block to be freed.
"""
allocator = self._block_ids_to_allocator[block.block_id]
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
......@@ -163,19 +170,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
allocator = self._block_ids_to_allocator[last_block.block_id]
block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int:
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
"""Returns the number of free blocks available on the specified device.
Args:
device (Device): The device for which to query the number of free
blocks.
blocks. AssertionError is raised if None is passed.
Returns:
int: The number of free blocks available on the specified device.
"""
assert device is not None
return self._allocators[device].get_num_free_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
......@@ -210,5 +220,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
def all_block_ids(self) -> frozenset[int]:
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
......@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
BlockId = int
class Block(ABC):
......@@ -15,6 +17,12 @@ class Block(ABC):
def block_id(self) -> Optional[int]:
pass
@block_id.setter
@abstractmethod
def block_id(self, value: Optional[int]) -> None:
"""NOTE: Do not use this API outside Block."""
self._block_id = value
@property
@abstractmethod
def token_ids(self) -> List[int]:
......@@ -35,6 +43,27 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]:
pass
@property
@abstractmethod
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
@abstractmethod
def computed(self, value) -> bool:
"""Should only be used by PrefixCachingBlockAllocator."""
raise NotImplementedError
@property
@abstractmethod
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
@abstractmethod
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
class Factory(Protocol):
@abstractmethod
......@@ -48,6 +77,17 @@ class Block(ABC):
) -> "Block":
pass
@property
@abstractmethod
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return None
class BlockAllocator(ABC):
......@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
token_ids: List[int]) -> Block:
pass
@abstractmethod
......@@ -69,7 +109,7 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
def get_num_free_blocks(self) -> int:
pass
@property
......@@ -82,11 +122,12 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def mark_blocks_as_accessed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
......@@ -94,21 +135,66 @@ class BlockAllocator(ABC):
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
"""NOTE: This should not be used outside Block."""
pass
@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used outside Block."""
pass
class NoFreeBlocksError(ValueError):
pass
class DeviceAwareBlockAllocator(BlockAllocator):
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
pass
@abstractmethod
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@property
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
from typing import Dict, Iterable, List, Optional, Set
from typing import Dict, FrozenSet, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
BlockId = int
Refcount = int
......@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator=self,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
......@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
......@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
assert device is None
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
......@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
)
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
# Mark the block as having no allocation.
......@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for block in source_blocks:
# Increment refcount for each block.
assert block.block_id is not None
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
......@@ -126,7 +133,8 @@ class NaiveBlockAllocator(BlockAllocator):
return forked_blocks
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
return len(self._free_block_indices)
def _allocate_new_block_id(self) -> BlockId:
......@@ -148,7 +156,7 @@ class NaiveBlockAllocator(BlockAllocator):
return self._refcounter
@property
def all_block_ids(self):
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
......@@ -200,6 +208,9 @@ class NaiveBlockAllocator(BlockAllocator):
"""
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
......@@ -224,13 +235,13 @@ class NaiveBlock(Block):
"""
def __init__(self,
prev_block: Block,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
_cow_target: Optional[Block] = None):
self._token_ids = []
self._token_ids: List[int] = []
self._block_size = block_size
self._prev_block = prev_block
self._block_id = block_id
......@@ -256,6 +267,22 @@ class NaiveBlock(Block):
assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids)
@property
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
def computed(self, value) -> None:
raise NotImplementedError
@property
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
@property
def block_id(self) -> Optional[int]:
return self._block_id
......@@ -276,9 +303,14 @@ class NaiveBlock(Block):
def token_ids(self) -> List[int]:
return self._token_ids
@property
def block_size(self) -> int:
return self._block_size
@property
def prev_block(self) -> Optional["Block"]:
return self._prev_block
@property
def content_hash(self) -> Optional[int]:
return None
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, Iterable, List, Optional
from typing import Dict, FrozenSet, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
PrefixHash = int
BlockId = int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find that a block still holds _DEFAULT_LAST_ACCESSED_TIME,
......@@ -38,7 +37,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
......@@ -49,7 +48,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block,
create_block=self._create_block, # type: ignore
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
......@@ -79,7 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: Optional[bool] = False,
computed: bool = False,
) -> Block:
# Bind block to self.
allocator = self
......@@ -93,8 +92,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
computed=computed,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
......@@ -105,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns:
Block: The allocated immutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
......@@ -127,16 +129,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return block
def allocate_mutable(self, prev_block: Block) -> Block:
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
None is not allowed, unlike in the superclass.
Returns:
Block: The allocated mutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
try:
......@@ -144,6 +150,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
prev_block=prev_block)
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
except BlockAllocator.NoFreeBlocksError:
......@@ -183,6 +190,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert block.content_hash is None
assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block
return block
......@@ -225,6 +233,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# In the fork case a block can have more than one reference,
# so we cannot free it from tracking if its refcount is larger than 1.
if refcount <= 1:
assert block.block_id is not None
del self._blocks[block.block_id]
return self._hashless_allocator.free(block)
......@@ -233,6 +242,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# If no longer used, add the block to the evictor.
if refcount == 0:
assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed)
......@@ -268,18 +278,18 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return forked_blocks
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks
# plus the number of blocks evictor could free from its list.
return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks
@property
def all_block_ids(self) -> frozenset[int]:
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self,
block: "PrefixCachingBlock") -> BlockId:
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
......@@ -289,7 +299,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.
Args:
block (PrefixCachingBlock): The mutable block to be promoted.
block: The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
......@@ -385,8 +395,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids
]
res = commonprefix([ids for ids in ids_list if ids != []])
return res
# commonprefix returns a list of ints here, although its type annotation
# says it returns a list of strings.
return commonprefix([
ids for ids in ids_list # type: ignore
if ids != []
])
class PrefixCachingBlock(Block):
......@@ -403,7 +416,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
......@@ -411,21 +424,25 @@ class PrefixCachingBlock(Block):
def __init__(
self,
prev_block: Optional["PrefixCachingBlock"],
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: PrefixCachingBlockAllocator,
prefix_caching_allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: Optional[bool] = False,
computed: bool = False,
):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
self.computed = computed
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._block = NaiveBlock(
prev_block=prev_block,
......@@ -436,6 +453,22 @@ class PrefixCachingBlock(Block):
_cow_target=self,
)
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
......@@ -483,7 +516,7 @@ class PrefixCachingBlock(Block):
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block = self
_block: Optional[Block] = self
self._cached_num_tokens_total = 0
# TODO: the current implementation here takes O(N^2); we expect future
......@@ -524,8 +557,10 @@ class PrefixCachingBlock(Block):
return None
is_first_block = self._prev_block is None
prev_block_hash = (None if is_first_block else
self._prev_block.content_hash)
prev_block_hash = (
None if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
......
......@@ -190,7 +190,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
......@@ -204,7 +204,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_ids = []
for block_id in block_table.physical_block_ids:
block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(block_ids, now)
self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching,
......@@ -227,8 +229,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids)
seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id]
......
......@@ -32,15 +32,20 @@ class Evictor(ABC):
@abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def update(self, block_id: int, last_accessed: int):
def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata"""
pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
......@@ -55,7 +60,7 @@ class BlockMetaData():
"""
def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
last_accessed: float):
self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed
......@@ -96,12 +101,12 @@ class LRUEvictor(Evictor):
return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens,
last_accessed)
def update(self, block_id: int, last_accessed: int):
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment