Commit e02ac556 (unverified), authored by Alexander Matveev, committed by GitHub
Browse files

[Performance] Optimize e2e overheads: Reduce python allocations (#7162)

parent 73388c07
...@@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder( ...@@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder(
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt) elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None): and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:] if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table) self.block_tables.append(block_table)
# Compute slot mapping. # Compute slot mapping.
......
...@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], ...@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# tokens are masked and the slot mapping will be # tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len): def add_slot(i):
block_number = block_table[i // block_size] block_number = block_table[i // block_size]
block_offset = i % block_size block_offset = i % block_size
slot = block_number * block_size + block_offset slot = block_number * block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
if start_idx == 0 and (seq_len - context_len) == 1:
# Optimization for common-case of decoding next token
add_slot(seq_len - 1)
else:
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
add_slot(i)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
......
"""Token blocks.""" """Token blocks."""
from typing import List from typing import List, Optional
from vllm.utils import Device from vllm.utils import Device
...@@ -37,5 +37,47 @@ class PhysicalTokenBlock: ...@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
f'computed={self.computed})') f'computed={self.computed})')
class BlockTable:
    """Holds a list of blocks with caching of their associated block_ids.

    Mapping: logical block number (the index) -> physical block. This class
    replaces the former plain alias ``BlockTable = List[PhysicalTokenBlock]``
    and additionally keeps a parallel list of ``block_number`` values so that
    ``ids()`` can return them without rebuilding a list on every call.

    Invariant: ``self._block_ids[i] == self._blocks[i].block_number`` for
    every index ``i``. All mutating methods below maintain it.
    """

    def __init__(self, blocks: Optional[List["PhysicalTokenBlock"]] = None):
        """Create a table, optionally pre-filled from ``blocks`` (in order)."""
        self._blocks: List["PhysicalTokenBlock"] = []
        self._block_ids: List[int] = []

        if blocks is not None:
            for block in blocks:
                self.append(block)

    def append(self, block: "PhysicalTokenBlock"):
        """Append ``block`` and record its ``block_number``."""
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, key):
        """Return the block at ``key``; a slice yields a plain list."""
        return self._blocks[key]

    def __setitem__(self, key, value):
        """Replace the block(s) at ``key`` and keep the cached ids in sync.

        For a slice ``key``, ``value`` may be any iterable of blocks. It is
        materialized exactly once, so a one-shot iterable (e.g. a generator)
        cannot be exhausted by the first assignment and leave the id cache
        out of step with the block list.
        """
        if isinstance(key, slice):
            blocks = list(value)
            # Derive the ids before touching any state so a bad element
            # cannot leave the two lists half-updated.
            block_ids = [b.block_number for b in blocks]
            self._blocks[key] = blocks
            self._block_ids[key] = block_ids
        else:
            block = value
            block_id = block.block_number
            self._blocks[key] = block
            self._block_ids[key] = block_id

    def reset(self):
        """Drop all blocks by rebinding both lists to fresh empty lists."""
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a new table holding the same block objects."""
        return BlockTable(self._blocks)

    def list(self) -> List["PhysicalTokenBlock"]:
        """Return the internal block list itself (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the internal cached id list itself (not a copy)."""
        return self._block_ids
...@@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase): ...@@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
self.num_blocks = num_blocks self.num_blocks = num_blocks
# Initialize the free blocks. # Initialize the free blocks.
self.free_blocks: BlockTable = [] self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks): for i in range(num_blocks):
block = PhysicalTokenBlock(device=device, block = PhysicalTokenBlock(device=device,
block_number=i, block_number=i,
...@@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
Device.CPU, block_size, num_cpu_blocks) Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable. # Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {} self.block_tables: Dict[int, BlockTable] = {}
# Mapping: req_id -> BlockTable # Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique # Note that each SequenceGroup has a unique
# request ID # request ID
...@@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = seq.n_blocks num_prompt_blocks = seq.n_blocks
block_table: BlockTable = [] block_table: BlockTable = BlockTable()
for logical_idx in range(num_prompt_blocks): for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window): and logical_idx >= self.block_sliding_window):
...@@ -326,15 +327,19 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -326,15 +327,19 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# #
# NOTE: Here we assume that all sequences in the group have the same # NOTE: Here we assume that all sequences in the group have the same
# decoder prompt. # decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
seq = wait_seqs[0]
block_table: BlockTable = \ block_table: BlockTable = \
self._allocate_sequence(seq, self._allocate_sequence(seq,
seq_group.num_seqs(), seq_group.num_seqs(),
is_encoder_decoder) is_encoder_decoder)
# Assign the self-attention block tables for each sequence. # Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): if len(wait_seqs) == 1:
self.block_tables[seq.seq_id] = block_table.copy() self.block_tables[wait_seqs[0].seq_id] = block_table
else:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
# Allocate encoder sequence # Allocate encoder sequence
if is_encoder_decoder: if is_encoder_decoder:
...@@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
return return
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy() self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused. # When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks. # In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count` # When forking, we must make sure that each block's `ref_count`
...@@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
dest_allocator: BlockAllocatorBase, dest_allocator: BlockAllocatorBase,
mapping: Dict[PhysicalTokenBlock, mapping: Dict[PhysicalTokenBlock,
PhysicalTokenBlock]) -> BlockTable: PhysicalTokenBlock]) -> BlockTable:
new_block_table = [] new_block_table: BlockTable = BlockTable()
for from_block in block_table: for from_block in block_table:
if from_block in mapping: if from_block in mapping:
...@@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \ self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id], self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator, self.cpu_allocator, self.gpu_allocator,
self.gpu_allocator,
mapping) mapping)
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
...@@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \ self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id], self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator, self.gpu_allocator, self.cpu_allocator,
self.cpu_allocator,
mapping) mapping)
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
...@@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables.clear() self.cross_block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]: def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id] return self.block_tables[seq.seq_id].ids()
return [block.block_number for block in block_table]
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
block_table = self.cross_block_tables[seq_group.request_id] block_table = self.cross_block_tables[seq_group.request_id]
......
...@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest ...@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
from vllm.utils import PyObjectCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -176,10 +177,10 @@ class SchedulerRunningOutputs: ...@@ -176,10 +177,10 @@ class SchedulerRunningOutputs:
enough memory, it can be preempted (for recompute) or swapped out. enough memory, it can be preempted (for recompute) or swapped out.
""" """
# Selected sequences that are running and in a decoding phase. # Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[SequenceGroup] decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase. # Selected sequences that are running and in a prefill phase.
# I.e., it means the prefill has been chunked. # I.e., it means the prefill has been chunked.
prefill_seq_groups: List[SequenceGroup] prefill_seq_groups: List[ScheduledSequenceGroup]
# The preempted sequences. # The preempted sequences.
preempted: List[SequenceGroup] preempted: List[SequenceGroup]
# Sequences that are swapped out. # Sequences that are swapped out.
...@@ -191,6 +192,10 @@ class SchedulerRunningOutputs: ...@@ -191,6 +192,10 @@ class SchedulerRunningOutputs:
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int num_lookahead_slots: int
# Optimization for fast-access to seq_group lists
decode_seq_groups_list: List[SequenceGroup]
prefill_seq_groups_list: List[SequenceGroup]
@classmethod @classmethod
def create_empty(cls) -> "SchedulerRunningOutputs": def create_empty(cls) -> "SchedulerRunningOutputs":
return SchedulerRunningOutputs( return SchedulerRunningOutputs(
...@@ -201,6 +206,8 @@ class SchedulerRunningOutputs: ...@@ -201,6 +206,8 @@ class SchedulerRunningOutputs:
blocks_to_swap_out=[], blocks_to_swap_out=[],
blocks_to_copy=[], blocks_to_copy=[],
num_lookahead_slots=0, num_lookahead_slots=0,
decode_seq_groups_list=[],
prefill_seq_groups_list=[],
) )
...@@ -259,6 +266,30 @@ class SchedulerPrefillOutputs: ...@@ -259,6 +266,30 @@ class SchedulerPrefillOutputs:
) )
def seq_group_metadata_builder():
    """Build a placeholder ``SequenceGroupMetadata``.

    Used as the factory handed to ``PyObjectCache`` by the scheduler. Every
    field holds a neutral value (empty request id, ``is_prompt=False``, empty
    dicts, no sampling params) that the consumer is expected to fill in.
    """
    blank_metadata = SequenceGroupMetadata(
        request_id="",
        is_prompt=False,
        seq_data={},
        sampling_params=None,
        block_tables={},
    )
    return blank_metadata
def scheduler_running_outputs_builder():
    """Build an empty ``SchedulerRunningOutputs``.

    Used as the factory handed to ``PyObjectCache`` by the scheduler. Each
    list field gets its own fresh empty list and the lookahead slot count
    starts at zero.
    """
    empty_outputs = SchedulerRunningOutputs(
        decode_seq_groups=[],
        prefill_seq_groups=[],
        preempted=[],
        swapped_out=[],
        blocks_to_swap_out=[],
        blocks_to_copy=[],
        num_lookahead_slots=0,
        prefill_seq_groups_list=[],
        decode_seq_groups_list=[],
    )
    return empty_outputs
def scheduled_seq_group_builder():
    """Build a placeholder ``ScheduledSequenceGroup``.

    Used as the factory handed to ``PyObjectCache`` by the scheduler. The
    sequence group is ``None`` and the chunk size is 0 until the consumer
    assigns real values.
    """
    placeholder = ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
    return placeholder
class Scheduler: class Scheduler:
def __init__( def __init__(
...@@ -331,6 +362,14 @@ class Scheduler: ...@@ -331,6 +362,14 @@ class Scheduler:
else 0) else 0)
self.num_cumulative_preemption: int = 0 self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
seq_group_metadata_builder)
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
@property @property
def lora_enabled(self) -> bool: def lora_enabled(self) -> bool:
return bool(self.lora_config) return bool(self.lora_config)
...@@ -441,14 +480,30 @@ class Scheduler: ...@@ -441,14 +480,30 @@ class Scheduler:
Returns: Returns:
SchedulerRunningOutputs. SchedulerRunningOutputs.
""" """
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
ret.prefill_seq_groups.clear()
ret.preempted.clear()
ret.swapped_out.clear()
ret.num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill=False)
ret.decode_seq_groups_list.clear()
ret.prefill_seq_groups_list.clear()
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
decode_seq_groups: List[ScheduledSequenceGroup] = [] decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[
preempted: List[SequenceGroup] = [] ScheduledSequenceGroup] = ret.prefill_seq_groups
swapped_out: List[SequenceGroup] = [] preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot # NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state. # to keep all the sequence groups in the RUNNING state.
...@@ -497,15 +552,19 @@ class Scheduler: ...@@ -497,15 +552,19 @@ class Scheduler:
else: else:
self._append_slots(seq_group, blocks_to_copy) self._append_slots(seq_group, blocks_to_copy)
is_prefill = seq_group.is_prefill() is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object()
scheduled_seq_group.seq_group = seq_group
if is_prefill: if is_prefill:
prefill_seq_groups.append( scheduled_seq_group.token_chunk_size = num_running_tokens
ScheduledSequenceGroup( prefill_seq_groups.append(scheduled_seq_group)
seq_group=seq_group, ret.prefill_seq_groups_list.append(seq_group)
token_chunk_size=num_running_tokens))
else: else:
decode_seq_groups.append( scheduled_seq_group.token_chunk_size = 1
ScheduledSequenceGroup(seq_group=seq_group, decode_seq_groups.append(scheduled_seq_group)
token_chunk_size=1)) ret.decode_seq_groups_list.append(seq_group)
budget.add_num_batched_tokens(seq_group.request_id, budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens) num_running_tokens)
# OPTIMIZATION: Note that get_max_num_running_seqs is # OPTIMIZATION: Note that get_max_num_running_seqs is
...@@ -518,15 +577,10 @@ class Scheduler: ...@@ -518,15 +577,10 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id) curr_loras.add(seq_group.lora_int_id)
return SchedulerRunningOutputs( self._scheduler_running_outputs_cache.reset()
decode_seq_groups=decode_seq_groups, self._scheduled_seq_group_cache.reset()
prefill_seq_groups=prefill_seq_groups,
preempted=preempted, return ret
swapped_out=swapped_out,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False))
def _schedule_swapped( def _schedule_swapped(
self, self,
...@@ -820,11 +874,15 @@ class Scheduler: ...@@ -820,11 +874,15 @@ class Scheduler:
# Update waiting requests. # Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups]) if len(prefills.seq_groups) > 0:
self.running.extend( self.running.extend([s.seq_group for s in prefills.seq_groups])
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend( self.running.extend(running_scheduled.decode_seq_groups_list)
[s.seq_group for s in swapped_in.decode_seq_groups])
if len(swapped_in.decode_seq_groups) > 0:
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests. # Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) + preempted = (len(running_scheduled.preempted) +
...@@ -834,18 +892,30 @@ class Scheduler: ...@@ -834,18 +892,30 @@ class Scheduler:
# doesn't allow chunked prefills. # doesn't allow chunked prefills.
assert len(running_scheduled.prefill_seq_groups) == 0 assert len(running_scheduled.prefill_seq_groups) == 0
assert len(swapped_in.prefill_seq_groups) == 0 assert len(swapped_in.prefill_seq_groups) == 0
# Merge lists
num_prefill_groups = len(prefills.seq_groups)
if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
else:
scheduled_seq_groups = running_scheduled.decode_seq_groups
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs( return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups + scheduled_seq_groups=scheduled_seq_groups,
running_scheduled.decode_seq_groups + num_prefill_groups=num_prefill_groups,
swapped_in.decode_seq_groups),
num_prefill_groups=len(prefills.seq_groups),
num_batched_tokens=budget.num_batched_tokens, num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy + blocks_to_copy=blocks_to_copy,
swapped_in.blocks_to_copy, ignored_seq_groups=ignored_seq_groups,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots, num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running), running_queue_size=len(self.running),
preempted=preempted, preempted=preempted,
...@@ -963,6 +1033,9 @@ class Scheduler: ...@@ -963,6 +1033,9 @@ class Scheduler:
scheduler_outputs = self._schedule() scheduler_outputs = self._schedule()
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate( for i, scheduled_seq_group in enumerate(
...@@ -971,10 +1044,15 @@ class Scheduler: ...@@ -971,10 +1044,15 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now) seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache.get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData # seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
# seq_id -> physical block numbers # seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int,
List[int]] = seq_group_metadata.block_tables
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup # Encoder associated with SequenceGroup
...@@ -993,9 +1071,10 @@ class Scheduler: ...@@ -993,9 +1071,10 @@ class Scheduler:
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now) self.block_manager.access_all_blocks_in_seq(seq, now)
common_computed_block_nums = ( if self.cache_config.enable_prefix_caching:
self.block_manager.get_common_computed_block_ids( common_computed_block_nums = (
seq_group.get_seqs(status=SequenceStatus.RUNNING))) self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True do_sample = True
if seq_group.is_prefill(): if seq_group.is_prefill():
...@@ -1014,7 +1093,8 @@ class Scheduler: ...@@ -1014,7 +1093,8 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
is_prompt = seq_group.is_prefill() is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
seq_group_metadata.__init__(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
seq_data=seq_data, seq_data=seq_data,
...@@ -1045,6 +1125,8 @@ class Scheduler: ...@@ -1045,6 +1125,8 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed( self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group) scheduled_seq_group.seq_group)
self._seq_group_metadata_cache.reset()
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
...@@ -1093,7 +1175,8 @@ class Scheduler: ...@@ -1093,7 +1175,8 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
cows = self.block_manager.append_slots(seq, num_lookahead_slots) cows = self.block_manager.append_slots(seq, num_lookahead_slots)
blocks_to_copy.extend(cows) if len(cows) > 0:
blocks_to_copy.extend(cows)
def _preempt( def _preempt(
self, self,
......
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter) PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingMetadataCache)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"SamplingMetadata", "SamplingMetadata",
"SamplingMetadataCache",
"set_random_seed", "set_random_seed",
"BasevLLMParameter", "BasevLLMParameter",
"PackedvLLMParameter", "PackedvLLMParameter",
......
...@@ -8,8 +8,9 @@ import torch ...@@ -8,8 +8,9 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (async_tensor_h2d, is_pin_memory_available, from vllm.utils import (PyObjectCache, async_tensor_h2d,
make_tensor_with_pad, maybe_expand_dim) is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558 _SEED_0_REPLACEMENT = 3403598558
...@@ -62,6 +63,39 @@ class SequenceGroupToSample: ...@@ -62,6 +63,39 @@ class SequenceGroupToSample:
assert self.query_len is not None assert self.query_len is not None
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-argument factory of placeholder ``SequenceGroupToSample``.

    Each call of the returned factory creates a new object whose ``seq_ids``
    is a list of ``num_seqs`` zeros and whose index lists are fresh and
    empty; the remaining fields hold neutral placeholder values.
    """

    def build_placeholder():
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return build_placeholder
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations.

    Holds one ``PyObjectCache`` per distinct sequence count, created on
    first request for that count.
    """

    def __init__(self):
        # num_seqs -> object cache producing SequenceGroupToSample objects
        # whose seq_ids list has exactly num_seqs entries.
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Return an object from the cache dedicated to ``num_seqs``."""
        per_size_cache = self._seq_group_to_sample_cache.get(num_seqs)
        if per_size_cache is None:
            per_size_cache = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))
            self._seq_group_to_sample_cache[num_seqs] = per_size_cache

        return per_size_cache.get_object()

    def reset(self):
        """Call ``reset()`` on every per-size cache."""
        for per_size_cache in self._seq_group_to_sample_cache.values():
            per_size_cache.reset()
class SamplingMetadata: class SamplingMetadata:
"""Metadata for input sequences. Used in sampler. """Metadata for input sequences. Used in sampler.
...@@ -121,6 +155,7 @@ class SamplingMetadata: ...@@ -121,6 +155,7 @@ class SamplingMetadata:
device: str, device: str,
pin_memory: bool, pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None, generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata": ) -> "SamplingMetadata":
( (
seq_groups, seq_groups,
...@@ -128,7 +163,7 @@ class SamplingMetadata: ...@@ -128,7 +163,7 @@ class SamplingMetadata:
categorized_sample_indices, categorized_sample_indices,
num_prompts, num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device, generators) device, generators, cache)
selected_token_indices = async_tensor_h2d(selected_token_indices, selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long, dtype=torch.long,
target_device=device, target_device=device,
...@@ -164,6 +199,7 @@ def _prepare_seq_groups( ...@@ -164,6 +199,7 @@ def _prepare_seq_groups(
query_lens: Optional[List[int]], query_lens: Optional[List[int]],
device: str, device: str,
generators: Optional[Dict[str, torch.Generator]] = None, generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]: SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling. """Prepare sequence groups and indices for sampling.
...@@ -210,15 +246,27 @@ def _prepare_seq_groups( ...@@ -210,15 +246,27 @@ def _prepare_seq_groups(
num_prompts = 0 num_prompts = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = seq_group_metadata.seq_data.keys()
if cache is not None:
sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
for j, seq_id in enumerate(seq_ids):
sample_obj.seq_ids[j] = seq_id
sample_obj.prompt_logprob_indices.clear()
sample_obj.sample_indices.clear()
sampling_params = seq_group_metadata.sampling_params sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None. # If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None seq_len: Optional[int] = None
query_len: Optional[int] = None query_len: Optional[int] = None
prompt_logprob_indices: List[int] = [] prompt_logprob_indices: List[int] = \
sample_indices: List[int] = [] sample_obj.prompt_logprob_indices if cache is not None else []
sample_indices: List[int] = \
sample_obj.sample_indices if cache is not None else []
do_sample = seq_group_metadata.do_sample do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt: if seq_group_metadata.is_prompt:
...@@ -290,9 +338,16 @@ def _prepare_seq_groups( ...@@ -290,9 +338,16 @@ def _prepare_seq_groups(
logit_idx += sample_len logit_idx += sample_len
sample_idx += sample_len sample_idx += sample_len
seq_groups.append( if cache is not None:
SequenceGroupToSample( sample_obj.sampling_params = sampling_params
seq_ids=seq_ids, sample_obj.seq_data = seq_group_metadata.seq_data
sample_obj.seq_len = seq_len
sample_obj.query_len = query_len
sample_obj.generator = generator
sample_obj.is_prompt = is_prompt
else:
sample_obj = SequenceGroupToSample(
seq_ids=list(seq_ids),
sampling_params=sampling_params, sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data, seq_data=seq_group_metadata.seq_data,
seq_len=seq_len, seq_len=seq_len,
...@@ -300,7 +355,13 @@ def _prepare_seq_groups( ...@@ -300,7 +355,13 @@ def _prepare_seq_groups(
generator=generator, generator=generator,
is_prompt=is_prompt, is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices), prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices))) sample_indices=list(sample_indices))
seq_groups.append(sample_obj)
if cache is not None:
cache.reset()
return (seq_groups, selected_token_indices, categorized_sample_indices, return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts) num_prompts)
......
...@@ -139,7 +139,7 @@ class RequestOutput: ...@@ -139,7 +139,7 @@ class RequestOutput:
CompletionOutput( CompletionOutput(
seqs.index(seq), seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length), seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(), seq.data._output_token_ids, # type: ignore
seq.get_cumulative_logprob() if include_logprobs else None, seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None, seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status), SequenceStatus.get_finished_reason(seq.status),
......
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
...@@ -330,7 +329,7 @@ class Sequence: ...@@ -330,7 +329,7 @@ class Sequence:
@property @property
def n_blocks(self) -> int: def n_blocks(self) -> int:
return math.ceil(self.get_len() / self.block_size) return (self.get_len() + self.block_size - 1) // self.block_size
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
...@@ -514,7 +513,9 @@ class SequenceGroup: ...@@ -514,7 +513,9 @@ class SequenceGroup:
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs = seqs self.seqs = seqs
self.is_single_seq = len(seqs) == 1
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.metrics = RequestMetrics(arrival_time=arrival_time, self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time, last_token_time=arrival_time,
...@@ -635,6 +636,10 @@ class SequenceGroup: ...@@ -635,6 +636,10 @@ class SequenceGroup:
) -> List[Sequence]: ) -> List[Sequence]:
if status is None: if status is None:
return self.seqs return self.seqs
if self.is_single_seq:
return self.seqs if self.seqs[0].status == status else []
return [seq for seq in self.seqs if seq.status == status] return [seq for seq in self.seqs if seq.status == status]
def is_encoder_decoder(self) -> bool: def is_encoder_decoder(self) -> bool:
...@@ -644,6 +649,9 @@ class SequenceGroup: ...@@ -644,6 +649,9 @@ class SequenceGroup:
return self.encoder_seq return self.encoder_seq
def get_unfinished_seqs(self) -> List[Sequence]: def get_unfinished_seqs(self) -> List[Sequence]:
if self.is_single_seq:
return self.seqs if not self.seqs[0].is_finished() else []
return [seq for seq in self.seqs if not seq.is_finished()] return [seq for seq in self.seqs if not seq.is_finished()]
def get_finished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]:
...@@ -668,12 +676,21 @@ class SequenceGroup: ...@@ -668,12 +676,21 @@ class SequenceGroup:
if status is None: if status is None:
return len(self.seqs) return len(self.seqs)
if self.is_single_seq:
return 1 if self.seqs[0].status == status else 0
return len(self.get_seqs(status)) return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int: def num_unfinished_seqs(self) -> int:
if self.is_single_seq:
return 1 if not self.seqs[0].is_finished() else 0
return len(self.get_unfinished_seqs()) return len(self.get_unfinished_seqs())
def num_finished_seqs(self) -> int: def num_finished_seqs(self) -> int:
if self.is_single_seq:
return 1 if self.seqs[0].is_finished() else 0
return len(self.get_finished_seqs()) return len(self.get_finished_seqs())
def find(self, seq_id: int) -> Sequence: def find(self, seq_id: int) -> Sequence:
...@@ -686,12 +703,14 @@ class SequenceGroup: ...@@ -686,12 +703,14 @@ class SequenceGroup:
raise ValueError(f"Sequence {seq.seq_id} already exists.") raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq self.seqs_dict[seq.seq_id] = seq
self.seqs.append(seq) self.seqs.append(seq)
self.is_single_seq = len(self.seqs) == 1
def remove(self, seq_id: int) -> None: def remove(self, seq_id: int) -> None:
seq = self.seqs_dict.pop(seq_id, None) seq = self.seqs_dict.pop(seq_id, None)
if seq is None: if seq is None:
raise ValueError(f"Sequence {seq_id} not found.") raise ValueError(f"Sequence {seq_id} not found.")
self.seqs.remove(seq) self.seqs.remove(seq)
self.is_single_seq = len(self.seqs) == 1
def is_finished(self) -> bool: def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs) return all(seq.is_finished() for seq in self.seqs)
...@@ -775,9 +794,10 @@ class SequenceGroupMetadata: ...@@ -775,9 +794,10 @@ class SequenceGroupMetadata:
# TODO: We should maintain this states out of the sequence group. # TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None self.num_speculative_tokens = None
if self._token_chunk_size is None: if seq_data is not None and self._token_chunk_size is None:
if is_prompt: if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len() self._token_chunk_size = next(iter(
seq_data.values())).get_len()
else: else:
self._token_chunk_size = 1 self._token_chunk_size = 1
......
...@@ -261,6 +261,44 @@ class LRUCache(Generic[T]): ...@@ -261,6 +261,44 @@ class LRUCache(Generic[T]):
self.cache.clear() self.cache.clear()
class PyObjectCache:
    """Pool of pre-built Python objects reused across scheduler iterations.

    Objects are created eagerly with ``obj_builder`` and handed out in order
    by :meth:`get_object`. Calling :meth:`reset` rewinds the hand-out cursor
    so the same objects are served again; the objects themselves are never
    re-created or cleared here, so callers must re-initialize whatever state
    they need on each object they receive.
    """

    def __init__(self, obj_builder, initial_size: int = 128):
        """Pre-allocate the pool.

        Args:
            obj_builder: Zero-argument callable returning a new object.
            initial_size: Number of objects to build up front. Defaults to
                128, the size that was previously hard-coded.
        """
        self._obj_builder = obj_builder
        # Position of the next object to hand out.
        self._index = 0
        self._obj_cache = [obj_builder() for _ in range(initial_size)]

    def _grow_cache(self):
        """Double the pool size, always adding at least one object."""
        # max(1, ...) lets a pool that started empty still grow; doubling
        # zero would add nothing and get_object() would fail its assert.
        num_new_objs = max(1, len(self._obj_cache))
        self._obj_cache.extend(self._obj_builder()
                               for _ in range(num_new_objs))

    def get_object(self):
        """Return the next pre-allocated object.

        If every object has already been handed out since the last
        :meth:`reset`, the pool is grown first.
        """
        if self._index >= len(self._obj_cache):
            self._grow_cache()
            assert self._index < len(self._obj_cache)

        obj = self._obj_cache[self._index]
        self._index += 1

        return obj

    def reset(self):
        """Make all cached objects available for the next scheduler
        iteration.
        """
        self._index = 0
def is_hip() -> bool: def is_hip() -> bool:
return torch.version.hip is not None return torch.version.hip is not None
......
import dataclasses import dataclasses
import gc import gc
import itertools
import time import time
import warnings import warnings
import weakref import weakref
...@@ -35,7 +36,7 @@ from vllm.logger import init_logger ...@@ -35,7 +36,7 @@ from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora, from vllm.model_executor.models.interfaces import (supports_lora,
...@@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import (
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists, from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
get_kv_cache_torch_dtype, is_hip, flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available) is_pin_memory_available)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
...@@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
class InterDataForSeqGroup: class InterDataForSeqGroup:
"""Intermediate data for the current sequence group.""" """Intermediate data for the current sequence group."""
def simple_reinit(self):
    """Reset this object's containers in place to their empty defaults.

    Only index 0 of the per-sequence containers is touched, so this is
    sufficient only when the object holds a single sequence. Existing
    list/set objects are emptied rather than replaced, so no new
    containers are allocated.
    """
    # Per-sequence token buffers: empty the first (only) inner list.
    for per_seq_buffers in (self.input_tokens, self.input_positions):
        per_seq_buffers[0].clear()  # type: ignore

    # Per-sequence scalar slots: zero the first (only) entry.
    for per_seq_values in (self.seq_lens, self.orig_seq_lens,
                           self.query_lens, self.context_lens,
                           self.curr_sliding_window_blocks):
        per_seq_values[0] = 0  # type: ignore

    # Group-level LoRA / prompt-adapter containers: empty them outright.
    for group_container in (self.lora_index_mapping,
                            self.lora_prompt_mapping,
                            self.lora_requests,
                            self.prompt_adapter_index_mapping,
                            self.prompt_adapter_prompt_mapping):
        group_container.clear()  # type: ignore
def __init__( def __init__(
self, self,
*, *,
...@@ -220,35 +235,121 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -220,35 +235,121 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Whether the prefix cache is hit (prefill only). # Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False, prefix_cache_hit: bool = False,
reinit: bool = False,
reinit_use_defaults: bool = False,
): ):
if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore
for i, seq_id in enumerate(seq_ids):
self.seq_ids[i] = seq_id # type: ignore
else:
self.seq_ids = seq_ids
self.request_id = request_id self.request_id = request_id
self.seq_ids = seq_ids
self.is_prompt = is_prompt self.is_prompt = is_prompt
self.block_tables = block_tables self.block_tables = block_tables
self.computed_block_nums = computed_block_nums self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs self.n_seqs = n_seqs
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
or [])
self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
or [])
self.prompt_adapter_request = prompt_adapter_request
if reinit:
if len(self.seq_ids) == 1 and reinit_use_defaults:
self.simple_reinit()
else:
if input_tokens:
self.input_tokens = input_tokens
else:
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
if input_positions:
self.input_positions = input_positions
else:
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
if seq_lens:
self.seq_lens = seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.seq_lens[seq_id] = 0
if orig_seq_lens:
self.orig_seq_lens = orig_seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0
if query_lens:
self.query_lens = query_lens
else:
for seq_id in range(len(self.seq_ids)):
self.query_lens[seq_id] = 0
if context_lens:
self.context_lens = context_lens
else:
for seq_id in range(len(self.seq_ids)):
self.context_lens[seq_id] = 0
if curr_sliding_window_blocks:
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks
else:
for seq_id in range(len(self.seq_ids)):
self.curr_sliding_window_blocks[seq_id] = 0
if lora_index_mapping:
self.lora_index_mapping = lora_index_mapping
else:
self.lora_index_mapping.clear()
if lora_prompt_mapping:
self.lora_prompt_mapping = lora_prompt_mapping
else:
self.lora_prompt_mapping.clear()
if lora_requests:
self.lora_requests = lora_requests
else:
self.lora_requests.clear()
if prompt_adapter_index_mapping:
self.prompt_adapter_index_mapping = \
prompt_adapter_index_mapping
else:
self.prompt_adapter_index_mapping.clear()
if prompt_adapter_prompt_mapping:
self.prompt_adapter_prompt_mapping = \
prompt_adapter_prompt_mapping
else:
self.prompt_adapter_prompt_mapping.clear()
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
self.__post_init__() if not reinit:
self.__post_init__()
def __post_init__(self): def __post_init__(self):
self.n_seqs = len(self.seq_ids) self.n_seqs = len(self.seq_ids)
...@@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.context_lens = [0] * self.n_seqs self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = [[] for _ in range(self.n_seqs)] self.lora_index_mapping = []
self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)] self.lora_prompt_mapping = []
def gen_inter_data_builder(self, num_seqs: int):
    """Return a zero-argument factory for placeholder inter-data objects.

    Each call of the returned factory constructs a new
    ``InterDataForSeqGroup`` with an empty request id, ``num_seqs``
    zero-valued sequence ids, ``is_prompt=True``, no block tables and an
    empty ``computed_block_nums`` list.
    """

    def build_placeholder():
        # Build a fresh list on every call so no two placeholder objects
        # share the same seq_ids list.
        placeholder_seq_ids = [0] * num_seqs
        return ModelInputForGPUBuilder.InterDataForSeqGroup(
            request_id="",
            seq_ids=placeholder_seq_ids,
            is_prompt=True,
            block_tables=None,
            computed_block_nums=[])

    return build_placeholder
def init_cached_inter_data(self, *args, **kwargs):
    """Fetch a cached inter-data object and re-run ``__init__`` on it.

    Only keyword arguments are accepted and ``seq_ids`` is required. The
    object comes from the runner's cache keyed by ``len(seq_ids)``; that
    cache entry is created on first use. All arguments are forwarded
    unchanged to the object's ``__init__``.

    Returns:
        The re-initialized cached object.
    """
    assert len(args) == 0
    assert "seq_ids" in kwargs
    num_seqs = len(kwargs["seq_ids"])

    # The inter-data cache is per model_runner
    pools_by_num_seqs = self.runner.inter_data_cache
    try:
        pool = pools_by_num_seqs[num_seqs]
    except KeyError:
        # First group of this size: create its pool lazily.
        pool = PyObjectCache(self.gen_inter_data_builder(num_seqs))
        pools_by_num_seqs[num_seqs] = pool

    cached_obj = pool.get_object()
    cached_obj.__init__(*args, **kwargs)
    return cached_obj
def reset_cached_inter_data(self):
    """Call ``reset()`` on every cache held in the runner's
    ``inter_data_cache`` mapping.
    """
    pools_by_num_seqs = self.runner.inter_data_cache
    for num_seqs in pools_by_num_seqs:
        pools_by_num_seqs[num_seqs].reset()
def __init__(self, def __init__(self,
runner: "GPUModelRunnerBase", runner: "GPUModelRunnerBase",
...@@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute tokens. # Compute tokens.
if inter_data.is_prompt: if inter_data.is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len] tokens = seq_data.get_token_ids()
if context_len != 0 or seq_len < len(tokens):
tokens = tokens[context_len:seq_len]
else: else:
# Optimization. get_token_ids requires the entire copy of # Optimization. get_token_ids requires the entire copy of
# tokens. # tokens.
tokens = [seq_data.get_last_token_id()] tokens = seq_data.get_last_token_id()
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx] = tokens
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) if isinstance(tokens, list):
inter_data.input_tokens[seq_idx].extend(tokens)
else:
inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
...@@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder.""" """Add a sequence group to the builder."""
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids) n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt is_prompt = seq_group_metadata.is_prompt
...@@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
assert n_seqs == 1 assert n_seqs == 1
self.decode_only = False self.decode_only = False
inter_data = self.InterDataForSeqGroup( inter_data = self.init_cached_inter_data(
request_id=seq_group_metadata.request_id, request_id=seq_group_metadata.request_id,
seq_ids=seq_ids, seq_ids=seq_ids,
is_prompt=is_prompt, is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables, block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums) computed_block_nums=seq_group_metadata.computed_block_nums,
reinit=True,
reinit_use_defaults=True)
self.inter_data_list.append(inter_data) self.inter_data_list.append(inter_data)
for seq_idx in range(n_seqs): for seq_idx in range(n_seqs):
...@@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors. create on-device tensors.
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = flatten_2d_lists([ input_tokens = []
flatten_2d_lists(inter_data.input_tokens) for inter_data in self.inter_data_list:
for inter_data in self.inter_data_list for cur_input_tokens in inter_data.input_tokens:
]) input_tokens.extend(cur_input_tokens)
if not input_tokens: if not input_tokens:
# This may happen when all prefill requests hit # This may happen when all prefill requests hit
# prefix caching and there is no decode request. # prefix caching and there is no decode request.
return self.model_input_cls() return self.model_input_cls()
input_positions = flatten_2d_lists([
flatten_2d_lists(inter_data.input_positions) input_positions = []
for inter_data in self.inter_data_list for inter_data in self.inter_data_list:
]) for cur_input_positions in inter_data.input_positions:
input_positions.extend(cur_input_positions)
seq_lens = [] seq_lens = []
max_decode_seq_len = 0 max_decode_seq_len = 0
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
...@@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not inter_data.is_prompt: if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len, max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens)) max(inter_data.seq_lens))
query_lens = flatten_2d_lists( query_lens = []
[inter_data.query_lens for inter_data in self.inter_data_list]) for inter_data in self.inter_data_list:
query_lens.extend(inter_data.query_lens)
# Mapping from request IDs to sequence IDs. Used for Jamba models # Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself. # that manages the cache by itself.
request_ids_to_seq_ids = { request_ids_to_seq_ids = {
...@@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
batch_size = graph_batch_size batch_size = graph_batch_size
# Tokens and positions. # Tokens and positions.
input_tokens.extend([0] * cuda_graph_pad_size) if cuda_graph_pad_size:
input_positions.extend([0] * cuda_graph_pad_size) input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
...@@ -558,7 +708,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -558,7 +708,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.pin_memory) self.runner.pin_memory)
# Sequence and query lengths. # Sequence and query lengths.
seq_lens.extend([1] * cuda_graph_pad_size) if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
...@@ -574,11 +725,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -574,11 +725,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
flatten_2d_lists(inter_data.lora_index_mapping) flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list for inter_data in self.inter_data_list
]) ])
lora_index_mapping.extend([0] * cuda_graph_pad_size) if cuda_graph_pad_size:
lora_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
lora_prompt_mapping = flatten_2d_lists([ lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping) flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list for inter_data in self.inter_data_list
]) ])
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
**dict(index_mapping=lora_index_mapping, **dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping, prompt_mapping=lora_prompt_mapping,
...@@ -595,7 +749,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -595,7 +749,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.prompt_adapter_index_mapping inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list for inter_data in self.inter_data_list
]) ])
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) if cuda_graph_pad_size:
prompt_adapter_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
prompt_adapter_prompt_mapping = flatten_2d_lists([ prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list for inter_data in self.inter_data_list
...@@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
set_cpu_offload_max_bytes( set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3)) int(self.cache_config.cpu_offload_gb * 1024**3))
# Used to cache python objects
self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
...@@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata) builder.add_seq_group(seq_group_metadata)
builder.reset_cached_inter_data()
return builder.build() # type: ignore return builder.build() # type: ignore
@torch.inference_mode() @torch.inference_mode()
...@@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens, seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory, model_input.query_lens, self.device, self.pin_memory,
generators) generators, self.sampling_metadata_cache)
else: else:
sampling_metadata = None sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt is_prompt = (seq_group_metadata_list[0].is_prompt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment