Unverified Commit 14ccd94c authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Core][Bugfix]Refactor block manager for better testability (#3492)

parent 8267b06c
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import count, takewhile from itertools import count, takewhile
from os.path import commonprefix from os.path import commonprefix
...@@ -7,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple ...@@ -7,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device from vllm.utils import Device
...@@ -196,21 +196,7 @@ class UncachedBlockAllocator(BlockAllocatorBase): ...@@ -196,21 +196,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
"Invalid codepath for uncached block allocator.") "Invalid codepath for uncached block allocator.")
class AllocStatus(enum.Enum): class BlockSpaceManagerV1(BlockSpaceManager):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated.
The capacity of allocator is larger than seq_group required.
3. Never: seq_group can never be allocated.
The seq_group is too large to allocated in GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager:
"""Manages the mapping between logical and physical token blocks.""" """Manages the mapping between logical and physical token blocks."""
def __init__( def __init__(
...@@ -355,6 +341,11 @@ class BlockSpaceManager: ...@@ -355,6 +341,11 @@ class BlockSpaceManager:
self, self,
seq: Sequence, seq: Sequence,
) -> PhysicalTokenBlock: ) -> PhysicalTokenBlock:
# Called before a new block is appended.
# This is in charge of allocating a new physical block (to be appended).
# None if the last block is not full. Otherwise, we set it to the
# content hash.
if not self.enable_caching: if not self.enable_caching:
return self.gpu_allocator.allocate() return self.gpu_allocator.allocate()
block_hash: Optional[int] = None block_hash: Optional[int] = None
...@@ -362,7 +353,14 @@ class BlockSpaceManager: ...@@ -362,7 +353,14 @@ class BlockSpaceManager:
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block( num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1) len(seq.logical_token_blocks) - 1)
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
# prefix tokens)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
# If the block has is None, then the block is not full.
# If the block is not full, then we expect it to have a refcount of 1.
if block_hash is None: if block_hash is None:
assert new_block.ref_count == 1 assert new_block.ref_count == 1
return new_block return new_block
...@@ -576,16 +574,16 @@ class BlockSpaceManager: ...@@ -576,16 +574,16 @@ class BlockSpaceManager:
for b in takewhile(lambda b: b.computed, block_table[:-1]) for b in takewhile(lambda b: b.computed, block_table[:-1])
] ]
def get_common_computed_block_ids(self, def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
seq_group: SequenceGroup) -> List[int]: """Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
"""
# Can return non-empty result only with prefix caching enabled. # Can return non-empty result only with prefix caching enabled.
if not self.enable_caching: if not self.enable_caching:
return [] return []
ids_list = [ ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
self.get_all_computed_blocks(seq)
for seq in iter(seq_group.seqs_dict.values())
]
return commonprefix([ids for ids in ids_list if ids != []]) return commonprefix([ids for ids in ids_list if ids != []])
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
......
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
SeqId = int
class BlockSpaceManagerV2(BlockSpaceManager):
"""BlockSpaceManager which manages the allocation of KV cache.
It owns responsibility for allocation, swapping, allocating memory for
autoregressively-generated tokens, and other advanced features such as
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
The current implementation is partial; in particular prefix caching and
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
watermark (float, optional): The threshold used for memory swapping.
Defaults to 0.01.
sliding_window (Optional[int], optional): The size of the sliding
window. Defaults to None.
enable_caching (bool, optional): Flag indicating whether caching is
enabled. Defaults to False.
"""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
assert sliding_window is None, "Sliding window not yet supported"
self.block_sliding_window = None
self.watermark = watermark
assert watermark >= 0.0
assert not enable_caching, "Prefix caching not yet supported"
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.block_allocator = CpuGpuBlockAllocator.create(
# Currently, only naive blocks are supported (no prefix caching).
allocator_type="naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
self.block_tables: Dict[SeqId, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
)
assert self.block_sliding_window is None
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
device=Device.GPU)
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def allocate(self, seq_group: SequenceGroup) -> None:
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert not (set(seq.seq_id for seq in waiting_seqs)
& self.block_tables.keys()), "block table already exists"
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = waiting_seqs[0]
block_table = BlockTable(
block_size=self.block_size,
block_allocator=self.block_allocator,
)
assert self.block_sliding_window is None
block_table.allocate(seq.get_token_ids())
self.block_tables[seq.seq_id] = block_table
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
block_table = self.block_tables[seq.seq_id]
# Get unseen token ids.
num_full_slots = block_table.num_full_slots
unseen_token_ids = seq.get_token_ids()[num_full_slots:]
assert unseen_token_ids
block_table.append_token_ids(unseen_token_ids)
# Return any copy-on-writes.
_ = self.block_allocator.clear_copy_on_writes()
# TODO extend append_slot interface to append_slots
# @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
return None
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids
def access_all_blocks_in_seq(self, seq, now):
# TODO add prefix caching support.
# Tracked here https://github.com/vllm-project/vllm/issues/3667
pass
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# We ignore the sequence group as its not necessary. After the batch is
# formed by the scheduler, we do not need to mark blocks from individual
# sequence groups as computed -- all blocks in the batch can be marked
# as computed.
self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
Currently, the attention implementation only supports skipping cached
blocks if they are a contiguous prefix of cached blocks.
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids)
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
return False
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return False
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
import enum
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from vllm.sequence import Sequence, SequenceGroup
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated.
The capacity of allocator is larger than seq_group required.
3. Never: seq_group can never be allocated.
The seq_group is too large to allocated in GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager(ABC):
@staticmethod
def get_block_space_manager_class(version: str):
version = version.lower()
if version == "v1":
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
return BlockSpaceManagerV1
if version == "v2":
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
raise ValueError(f"Unknown version {version=}")
@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
pass
@abstractmethod
def allocate(self, seq_group: SequenceGroup) -> None:
pass
@abstractmethod
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
pass
@abstractmethod
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
pass
@abstractmethod
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
pass
@abstractmethod
def free(self, seq: Sequence) -> None:
pass
@abstractmethod
def get_block_table(self, seq: Sequence) -> List[int]:
pass
@abstractmethod
def get_num_free_gpu_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_cpu_blocks(self) -> int:
pass
@abstractmethod
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
pass
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
...@@ -4,7 +4,7 @@ from collections import deque ...@@ -4,7 +4,7 @@ from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -88,8 +88,13 @@ class Scheduler: ...@@ -88,8 +88,13 @@ class Scheduler:
# Instantiate the scheduling policy. # Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs") self.policy = PolicyFactory.get_policy(policy_name="fcfs")
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config.
use_v2_block_manager else "v1")
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManager( self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks, num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
...@@ -378,6 +383,10 @@ class Scheduler: ...@@ -378,6 +383,10 @@ class Scheduler:
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now) self.block_manager.access_all_blocks_in_seq(seq, now)
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run, is_prompt=scheduler_outputs.prompt_run,
...@@ -385,8 +394,7 @@ class Scheduler: ...@@ -385,8 +394,7 @@ class Scheduler:
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
computed_block_nums=self.block_manager. computed_block_nums=common_computed_block_nums,
get_common_computed_block_ids(seq_group),
state=seq_group.state, state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
...@@ -396,6 +404,14 @@ class Scheduler: ...@@ -396,6 +404,14 @@ class Scheduler:
if scheduler_outputs.prompt_run else None, if scheduler_outputs.prompt_run else None,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(seq_group)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
...@@ -503,9 +519,6 @@ class Scheduler: ...@@ -503,9 +519,6 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED seq.status = SequenceStatus.SWAPPED
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
self.block_manager.mark_blocks_as_computed(seq_group)
def _passed_delay(self, now: float) -> bool: def _passed_delay(self, now: float) -> bool:
if self.prev_prompt: if self.prev_prompt:
self.last_prompt_latency = now - self.prev_time self.last_prompt_latency = now - self.prev_time
......
...@@ -28,6 +28,7 @@ class EngineArgs: ...@@ -28,6 +28,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
block_size: int = 16 block_size: int = 16
enable_prefix_caching: bool = False enable_prefix_caching: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
...@@ -52,6 +53,9 @@ class EngineArgs: ...@@ -52,6 +53,9 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
image_input_type: Optional[str] = None image_input_type: Optional[str] = None
image_token_id: Optional[int] = None image_token_id: Optional[int] = None
...@@ -194,6 +198,9 @@ class EngineArgs: ...@@ -194,6 +198,9 @@ class EngineArgs:
parser.add_argument('--enable-prefix-caching', parser.add_argument('--enable-prefix-caching',
action='store_true', action='store_true',
help='Enables automatic prefix caching') help='Enables automatic prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
...@@ -210,6 +217,12 @@ class EngineArgs: ...@@ -210,6 +217,12 @@ class EngineArgs:
help='the fraction of GPU memory to be used for ' help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.' 'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.') 'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--forced-num-gpu-blocks',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
'of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens', parser.add_argument('--max-num-batched-tokens',
type=int, type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
...@@ -369,6 +382,7 @@ class EngineArgs: ...@@ -369,6 +382,7 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
self.forced_num_gpu_blocks,
model_config.get_sliding_window(), model_config.get_sliding_window(),
self.enable_prefix_caching) self.enable_prefix_caching)
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
...@@ -383,6 +397,7 @@ class EngineArgs: ...@@ -383,6 +397,7 @@ class EngineArgs:
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len, model_config.max_model_len,
self.use_v2_block_manager,
self.scheduler_delay_factor) self.scheduler_delay_factor)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
......
...@@ -553,12 +553,6 @@ class LLMEngine: ...@@ -553,12 +553,6 @@ class LLMEngine:
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
# If prefix caching is enabled, mark all blocks in the sequence groups
# as completed so that future requests don't attempt to recompute them
if self.cache_config.enable_prefix_caching:
for seq_group in scheduled_seq_groups:
self.scheduler.mark_blocks_as_computed(seq_group)
for seq_group, outputs in zip(scheduled_seq_groups, output): for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs) self._process_sequence_group_outputs(seq_group, outputs)
......
...@@ -85,6 +85,12 @@ class GPUExecutor(ExecutorBase): ...@@ -85,6 +85,12 @@ class GPUExecutor(ExecutorBase):
cache_dtype=self.cache_config.cache_dtype, cache_dtype=self.cache_config.cache_dtype,
)) ))
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
logger.info(f"# GPU blocks: {num_gpu_blocks}, " logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}") f"# CPU blocks: {num_cpu_blocks}")
......
...@@ -232,6 +232,13 @@ class RayGPUExecutor(ExecutorBase): ...@@ -232,6 +232,13 @@ class RayGPUExecutor(ExecutorBase):
# operators can be applied to all workers. # operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks) num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks)
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
logger.info(f"# GPU blocks: {num_gpu_blocks}, " logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}") f"# CPU blocks: {num_cpu_blocks}")
......
...@@ -196,6 +196,8 @@ class Sequence: ...@@ -196,6 +196,8 @@ class Sequence:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
# Compute the number of tokens in the sequence # Compute the number of tokens in the sequence
# TODO: The current hashing function is O(L^2). We should optimize # TODO: The current hashing function is O(L^2). We should optimize
# this in the future. # this in the future.
......
...@@ -227,6 +227,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None: ...@@ -227,6 +227,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def chunk_list(lst, chunk_size):
"""Yield successive chunk_size chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]: def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME') cuda_home = os.environ.get('CUDA_HOME')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment