Unverified Commit cf8cac8c authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[mypy][6/N] Fix all the core subdirectory typing (#4450)


Co-authored-by: default avatarCade Daniel <edacih@gmail.com>
parent 5e401bce
...@@ -33,6 +33,7 @@ jobs: ...@@ -33,6 +33,7 @@ jobs:
- name: Mypy - name: Mypy
run: | run: |
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
...@@ -42,9 +43,6 @@ jobs: ...@@ -42,9 +43,6 @@ jobs:
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
...@@ -95,7 +95,7 @@ echo 'vLLM yapf: Done' ...@@ -95,7 +95,7 @@ echo 'vLLM yapf: Done'
# Run mypy # Run mypy
echo 'vLLM mypy:' echo 'vLLM mypy:'
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
......
...@@ -40,7 +40,9 @@ class BlockTable: ...@@ -40,7 +40,9 @@ class BlockTable:
): ):
self._block_size = block_size self._block_size = block_size
self._allocator = block_allocator self._allocator = block_allocator
self._blocks: Optional[List[Block]] = _blocks if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
# Use helper method instead of directly calculating, as blocks # Use helper method instead of directly calculating, as blocks
# may not be allocated. # may not be allocated.
...@@ -104,7 +106,7 @@ class BlockTable: ...@@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended. token_ids (List[int]): The sequence of token IDs to be appended.
""" """
assert self._is_allocated assert self._is_allocated
assert self._blocks is not None assert len(self._blocks) > 0
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots) num_lookahead_slots)
...@@ -141,6 +143,7 @@ class BlockTable: ...@@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate): for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append( self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1], self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device)) device=device))
...@@ -159,6 +162,7 @@ class BlockTable: ...@@ -159,6 +162,7 @@ class BlockTable:
the current instance. the current instance.
""" """
assert self._is_allocated assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1]) forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable( return BlockTable(
block_size=self._block_size, block_size=self._block_size,
...@@ -177,10 +181,10 @@ class BlockTable: ...@@ -177,10 +181,10 @@ class BlockTable:
assert self._is_allocated assert self._is_allocated
for block in self._blocks: for block in self._blocks:
self._allocator.free(block) self._allocator.free(block)
self._blocks = None self._blocks = []
@property @property
def physical_block_ids(self) -> List[int]: def physical_block_ids(self) -> List[Optional[int]]:
"""Returns a list of physical block indices for the blocks in the """Returns a list of physical block indices for the blocks in the
BlockTable. BlockTable.
...@@ -235,7 +239,7 @@ class BlockTable: ...@@ -235,7 +239,7 @@ class BlockTable:
def _get_all_token_ids(self) -> List[int]: def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly. # NOTE: This function is O(seq_len); use sparingly.
token_ids = [] token_ids: List[int] = []
if not self._is_allocated: if not self._is_allocated:
return token_ids return token_ids
...@@ -247,7 +251,7 @@ class BlockTable: ...@@ -247,7 +251,7 @@ class BlockTable:
@property @property
def _is_allocated(self) -> bool: def _is_allocated(self) -> bool:
return self._blocks is not None return len(self._blocks) > 0
@property @property
def _num_empty_slots(self) -> int: def _num_empty_slots(self) -> int:
......
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional, Protocol
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator
...@@ -7,7 +7,19 @@ BlockId = int ...@@ -7,7 +7,19 @@ BlockId = int
RefCount = int RefCount = int
class RefCounter: class RefCounterProtocol(Protocol):
def incr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def decr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def get(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
class RefCounter(RefCounterProtocol):
"""A class for managing reference counts for a set of block indices. """A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their The RefCounter class maintains a dictionary that maps block indices to their
...@@ -54,7 +66,7 @@ class RefCounter: ...@@ -54,7 +66,7 @@ class RefCounter:
return ReadOnlyRefCounter(self) return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter: class ReadOnlyRefCounter(RefCounterProtocol):
"""A read-only view of the RefCounter class. """A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the The ReadOnlyRefCounter class provides a read-only interface to access the
...@@ -96,7 +108,7 @@ class CopyOnWriteTracker: ...@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def __init__( def __init__(
self, self,
refcounter: RefCounter, refcounter: RefCounterProtocol,
allocator: BlockAllocator, allocator: BlockAllocator,
): ):
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
......
from typing import Dict, List, Optional from typing import Dict, FrozenSet, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator, from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator) DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
...@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids = block_ids[num_gpu_blocks:] cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive": if allocator_type == "naive":
gpu_allocator = NaiveBlockAllocator( gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks, num_blocks=num_gpu_blocks,
block_size=block_size, block_size=block_size,
block_ids=gpu_block_ids, block_ids=gpu_block_ids,
) )
cpu_allocator = NaiveBlockAllocator( cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks, num_blocks=num_cpu_blocks,
block_size=block_size, block_size=block_size,
block_ids=cpu_block_ids, block_ids=cpu_block_ids,
...@@ -105,13 +105,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -105,13 +105,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.GPU: gpu_block_allocator, Device.GPU: gpu_block_allocator,
} }
self._block_ids_to_allocator = {} self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items(): for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids: for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator self._block_ids_to_allocator[block_id] = allocator
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable(self,
device: Device) -> Block: prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block on the specified device. """Allocates a new mutable block on the specified device.
Args: Args:
...@@ -122,10 +123,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -122,10 +123,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
assert device is not None
return self._allocators[device].allocate_mutable(prev_block) return self._allocators[device].allocate_mutable(prev_block)
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int], device: Device) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the """Allocates a new immutable block with the provided token IDs on the
specified device. specified device.
...@@ -140,6 +144,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -140,6 +144,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided Block: The newly allocated immutable block containing the provided
token IDs. token IDs.
""" """
assert device is not None
return self._allocators[device].allocate_immutable( return self._allocators[device].allocate_immutable(
prev_block, token_ids) prev_block, token_ids)
...@@ -149,7 +154,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -149,7 +154,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args: Args:
block (Block): The block to be freed. block (Block): The block to be freed.
""" """
allocator = self._block_ids_to_allocator[block.block_id] block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block) return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
...@@ -163,19 +170,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -163,19 +170,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the List[Block]: A new list of blocks that shares the same memory as the
original sequence. original sequence.
""" """
allocator = self._block_ids_to_allocator[last_block.block_id] block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block) return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
"""Returns the number of free blocks available on the specified device. """Returns the number of free blocks available on the specified device.
Args: Args:
device (Device): The device for which to query the number of free device (Device): The device for which to query the number of free
blocks. blocks. AssertionError is raised if None is passed.
Returns: Returns:
int: The number of free blocks available on the specified device. int: The number of free blocks available on the specified device.
""" """
assert device is not None
return self._allocators[device].get_num_free_blocks() return self._allocators[device].get_num_free_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]: def clear_copy_on_writes(self) -> Dict[int, List[int]]:
...@@ -210,5 +220,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -210,5 +220,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return self._allocators[device].get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids(
seq_block_ids) seq_block_ids)
def all_block_ids(self) -> frozenset[int]: @property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys()) return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
...@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol ...@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device from vllm.utils import Device
BlockId = int
class Block(ABC): class Block(ABC):
...@@ -15,6 +17,12 @@ class Block(ABC): ...@@ -15,6 +17,12 @@ class Block(ABC):
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
pass pass
@block_id.setter
@abstractmethod
def block_id(self, value: Optional[int]) -> None:
"""NOTE: Do not use this API outside Block."""
self._block_id = value
@property @property
@abstractmethod @abstractmethod
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
...@@ -35,6 +43,27 @@ class Block(ABC): ...@@ -35,6 +43,27 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
pass pass
@property
@abstractmethod
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
@abstractmethod
def computed(self, value) -> bool:
"""Should be only used by PrefixCacingAllocator"""
raise NotImplementedError
@property
@abstractmethod
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
@abstractmethod
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
class Factory(Protocol): class Factory(Protocol):
@abstractmethod @abstractmethod
...@@ -48,6 +77,17 @@ class Block(ABC): ...@@ -48,6 +77,17 @@ class Block(ABC):
) -> "Block": ) -> "Block":
pass pass
@property
@abstractmethod
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return None
class BlockAllocator(ABC): class BlockAllocator(ABC):
...@@ -57,7 +97,7 @@ class BlockAllocator(ABC): ...@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block: token_ids: List[int]) -> Block:
pass pass
@abstractmethod @abstractmethod
...@@ -69,7 +109,7 @@ class BlockAllocator(ABC): ...@@ -69,7 +109,7 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self) -> int:
pass pass
@property @property
...@@ -82,11 +122,12 @@ class BlockAllocator(ABC): ...@@ -82,11 +122,12 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def mark_blocks_as_accessed(self) -> None: def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass pass
@abstractmethod @abstractmethod
def mark_blocks_as_computed(self) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod @abstractmethod
...@@ -94,21 +135,66 @@ class BlockAllocator(ABC): ...@@ -94,21 +135,66 @@ class BlockAllocator(ABC):
self, seq_block_ids: List[List[int]]) -> List[int]: self, seq_block_ids: List[List[int]]) -> List[int]:
pass pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
"""NOTE: This should not be used besides Block"""
pass
@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used besides Block"""
pass
class NoFreeBlocksError(ValueError): class NoFreeBlocksError(ValueError):
pass pass
class DeviceAwareBlockAllocator(BlockAllocator): class DeviceAwareBlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int], device: Device) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass pass
from typing import Dict, Iterable, List, Optional, Set from typing import Dict, FrozenSet, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
BlockId = int
Refcount = int Refcount = int
...@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator=self, allocator=self,
) )
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int]) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to """Allocates a new immutable block with the given token IDs, linked to
the previous block. the previous block.
...@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The newly allocated immutable block. Block: The newly allocated immutable block.
""" """
assert device is None
block = self.allocate_mutable(prev_block=prev_block) block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids) block.append_token_ids(token_ids)
return block return block
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block. """Allocates a new mutable block, linked to the previous block.
Args: Args:
...@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
assert device is None
block_id = self._allocate_new_block_id() block_id = self._allocate_new_block_id()
return self._create_block( return self._create_block(
prev_block=prev_block, prev_block=prev_block,
...@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
) )
def free(self, block: Block) -> None: def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id) self._free_block_id(block.block_id)
# Mark the block as having no allocation. # Mark the block as having no allocation.
...@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for block in source_blocks: for block in source_blocks:
# Increment refcount for each block. # Increment refcount for each block.
assert block.block_id is not None
refcount = self._refcounter.incr(block.block_id) refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block" assert refcount != 1, "can't fork free'd block"
...@@ -126,7 +133,8 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -126,7 +133,8 @@ class NaiveBlockAllocator(BlockAllocator):
return forked_blocks return forked_blocks
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
return len(self._free_block_indices) return len(self._free_block_indices)
def _allocate_new_block_id(self) -> BlockId: def _allocate_new_block_id(self) -> BlockId:
...@@ -148,7 +156,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -148,7 +156,7 @@ class NaiveBlockAllocator(BlockAllocator):
return self._refcounter return self._refcounter
@property @property
def all_block_ids(self): def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
...@@ -200,6 +208,9 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -200,6 +208,9 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
return [] return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
class NaiveBlock(Block): class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix """An implementation of the Block class that does not support prefix
...@@ -224,13 +235,13 @@ class NaiveBlock(Block): ...@@ -224,13 +235,13 @@ class NaiveBlock(Block):
""" """
def __init__(self, def __init__(self,
prev_block: Block, prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
block_size: int, block_size: int,
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
_cow_target: Optional[Block] = None): _cow_target: Optional[Block] = None):
self._token_ids = [] self._token_ids: List[int] = []
self._block_size = block_size self._block_size = block_size
self._prev_block = prev_block self._prev_block = prev_block
self._block_id = block_id self._block_id = block_id
...@@ -256,6 +267,22 @@ class NaiveBlock(Block): ...@@ -256,6 +267,22 @@ class NaiveBlock(Block):
assert self.num_empty_slots >= len(token_ids) assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids) self._token_ids.extend(token_ids)
@property
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
def computed(self, value) -> None:
raise NotImplementedError
@property
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
@property @property
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
return self._block_id return self._block_id
...@@ -276,9 +303,14 @@ class NaiveBlock(Block): ...@@ -276,9 +303,14 @@ class NaiveBlock(Block):
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
return self._token_ids return self._token_ids
@property
def block_size(self) -> int: def block_size(self) -> int:
return self._block_size return self._block_size
@property @property
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
return self._prev_block return self._prev_block
@property
def content_hash(self) -> Optional[int]:
return None
"""Token blocks.""" """Token blocks."""
from itertools import takewhile from itertools import takewhile
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, Iterable, List, Optional from typing import Dict, FrozenSet, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker, from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
PrefixHash = int PrefixHash = int
BlockId = int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, # so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
...@@ -38,7 +37,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -38,7 +37,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks: int, num_blocks: int,
block_size: int, block_size: int,
block_ids: Optional[Iterable[int]] = None, block_ids: Optional[Iterable[int]] = None,
eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU, eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
): ):
# A mapping of prefix hash to block index. All blocks which have a # A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0. # prefix hash will be in this dict, even if they have refcount 0.
...@@ -49,7 +48,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -49,7 +48,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# An allocator for blocks that do not have prefix hashes. # An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator( self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block, create_block=self._create_block, # type: ignore
num_blocks=num_blocks, num_blocks=num_blocks,
block_size=block_size, block_size=block_size,
block_ids=block_ids, block_ids=block_ids,
...@@ -79,7 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -79,7 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size: int, block_size: int,
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: Optional[bool] = False, computed: bool = False,
) -> Block: ) -> Block:
# Bind block to self. # Bind block to self.
allocator = self allocator = self
...@@ -93,8 +92,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -93,8 +92,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
computed=computed, computed=computed,
) )
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int]) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached """Allocates an immutable block with the given token IDs, reusing cached
blocks if possible. blocks if possible.
...@@ -105,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -105,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The allocated immutable block. Block: The allocated immutable block.
""" """
assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
block = self._create_block( block = self._create_block(
...@@ -127,16 +129,20 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -127,16 +129,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return block return block
def allocate_mutable(self, prev_block: Block) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will """Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks. evict unused cached blocks.
Args: Args:
prev_block (Block): The previous block in the sequence. prev_block (Block): The previous block in the sequence.
None is not allowed unlike it is super class.
Returns: Returns:
Block: The allocated mutable block. Block: The allocated mutable block.
""" """
assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
try: try:
...@@ -144,6 +150,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -144,6 +150,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
prev_block=prev_block) prev_block=prev_block)
assert block.block_id not in self._blocks assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block self._blocks[block.block_id] = block
return block return block
except BlockAllocator.NoFreeBlocksError: except BlockAllocator.NoFreeBlocksError:
...@@ -183,6 +190,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -183,6 +190,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert block.content_hash is None assert block.content_hash is None
assert block.block_id not in self._blocks assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block self._blocks[block.block_id] = block
return block return block
...@@ -225,6 +233,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -225,6 +233,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# We have fork case where block would get more than one ref, # We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1 # so we cannot free it from tracking if ref cnt large than 1
if refcount <= 1: if refcount <= 1:
assert block.block_id is not None
del self._blocks[block.block_id] del self._blocks[block.block_id]
return self._hashless_allocator.free(block) return self._hashless_allocator.free(block)
...@@ -233,6 +242,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -233,6 +242,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# If no longer used, add the block to the evictor. # If no longer used, add the block to the evictor.
if refcount == 0: if refcount == 0:
assert block.content_hash in self._cached_blocks assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id] del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash, self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed) block.num_tokens_total, block.last_accessed)
...@@ -268,18 +278,18 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -268,18 +278,18 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return forked_blocks return forked_blocks
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks # The number of free blocks is the number of hashless free blocks
# plus the number of blocks evictor could free from its list. # plus the number of blocks evictor could free from its list.
return self._hashless_allocator.get_num_free_blocks( return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks ) + self.evictor.num_blocks
@property @property
def all_block_ids(self) -> frozenset[int]: def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self, def promote_to_immutable_block(self, block: Block) -> BlockId:
block: "PrefixCachingBlock") -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable """Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks block. This means that its content can be referenced by future blocks
having the same prefix. having the same prefix.
...@@ -289,7 +299,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -289,7 +299,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block. block.
Args: Args:
block (PrefixCachingBlock): The mutable block to be promoted. block: The mutable block to be promoted.
Returns: Returns:
BlockId: Either the original block index, or the block index of BlockId: Either the original block index, or the block index of
...@@ -385,8 +395,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -385,8 +395,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
takewhile(lambda block_id: self.block_is_computed(block_id), takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids seq[:-1])) for seq in seq_block_ids
] ]
res = commonprefix([ids for ids in ids_list if ids != []]) # It returns a list of int although type annotation says list of string.
return res return commonprefix([
ids for ids in ids_list # type: ignore
if ids != []
])
class PrefixCachingBlock(Block): class PrefixCachingBlock(Block):
...@@ -403,7 +416,7 @@ class PrefixCachingBlock(Block): ...@@ -403,7 +416,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block. token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in block_size (int): The maximum number of token IDs that can be stored in
the block. the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block. caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index block_id (Optional[int], optional): The physical block index
of this block. Defaults to None. of this block. Defaults to None.
...@@ -411,21 +424,25 @@ class PrefixCachingBlock(Block): ...@@ -411,21 +424,25 @@ class PrefixCachingBlock(Block):
def __init__( def __init__(
self, self,
prev_block: Optional["PrefixCachingBlock"], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
block_size: int, block_size: int,
prefix_caching_allocator: PrefixCachingBlockAllocator, prefix_caching_allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: Optional[bool] = False, computed: bool = False,
): ):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator self._prefix_caching_allocator = prefix_caching_allocator
self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self.computed = computed self._computed = computed
self._block = NaiveBlock( self._block = NaiveBlock(
prev_block=prev_block, prev_block=prev_block,
...@@ -436,6 +453,22 @@ class PrefixCachingBlock(Block): ...@@ -436,6 +453,22 @@ class PrefixCachingBlock(Block):
_cow_target=self, _cow_target=self,
) )
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int]) -> None: def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as """Appends the given token IDs to the block and registers the block as
immutable if the block becomes full. immutable if the block becomes full.
...@@ -483,7 +516,7 @@ class PrefixCachingBlock(Block): ...@@ -483,7 +516,7 @@ class PrefixCachingBlock(Block):
if self._cached_num_tokens_total is not None: if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total return self._cached_num_tokens_total
_block = self _block: Optional[Block] = self
self._cached_num_tokens_total = 0 self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future # TODO: current implement here take O(N^2), we expect future
...@@ -524,8 +557,10 @@ class PrefixCachingBlock(Block): ...@@ -524,8 +557,10 @@ class PrefixCachingBlock(Block):
return None return None
is_first_block = self._prev_block is None is_first_block = self._prev_block is None
prev_block_hash = (None if is_first_block else prev_block_hash = (
self._prev_block.content_hash) None if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash. # Previous block exists but does not yet have a hash.
# Return no hash in this case. # Return no hash in this case.
......
...@@ -190,7 +190,7 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -190,7 +190,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
assert seq.seq_id in self.block_tables assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids) assert all(b is not None for b in block_ids)
return block_ids return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float): def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed # Update the last accessed time of all the blocks accessed
...@@ -204,7 +204,9 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -204,7 +204,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_ids = [] block_ids = []
for block_id in block_table.physical_block_ids: for block_id in block_table.physical_block_ids:
block_ids.append(block_id) block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(block_ids, now) self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching, # The only need for mark block as computed is for prefix caching,
...@@ -227,8 +229,9 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -227,8 +229,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq_block_ids = [ seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
] ]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids( return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
......
...@@ -32,15 +32,20 @@ class Evictor(ABC): ...@@ -32,15 +32,20 @@ class Evictor(ABC):
@abstractmethod @abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction""" """Adds block to the evictor, making it a candidate for eviction"""
pass pass
@abstractmethod @abstractmethod
def update(self, block_id: int, last_accessed: int): def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata""" """Update corresponding block's access time in metadata"""
pass pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@abstractproperty @abstractproperty
def num_blocks(self) -> int: def num_blocks(self) -> int:
pass pass
...@@ -55,7 +60,7 @@ class BlockMetaData(): ...@@ -55,7 +60,7 @@ class BlockMetaData():
""" """
def __init__(self, content_hash: int, num_hashed_tokens: int, def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
self.content_hash = content_hash self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed self.last_accessed = last_accessed
...@@ -96,12 +101,12 @@ class LRUEvictor(Evictor): ...@@ -96,12 +101,12 @@ class LRUEvictor(Evictor):
return evicted_block_id, evicted_block.content_hash return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash, self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens, num_hashed_tokens,
last_accessed) last_accessed)
def update(self, block_id: int, last_accessed: int): def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int): def remove(self, block_id: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment