Unverified Commit c99db8c8 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove V0 core (#25321)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 72dd1595
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import heapq
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed Blocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_id: int) -> bool:
pass
@abstractmethod
def evict(self) -> Tuple[int, int]:
"""Runs the eviction algorithm and returns the evicted block's
content hash along with physical block id along with physical block id
"""
pass
@abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata"""
pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@property
@abstractmethod
def num_blocks(self) -> int:
pass
class BlockMetaData:
"""Data structure for storing key data describe cached block, so that
evictor could use to make its decision which one to choose for eviction
Here we use physical block id as the dict key, as there maybe several
blocks with the same content hash, but their physical id is unique.
"""
def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chosen arbitrarily
"""
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority
# queue relative to the free table size. When this threshold is exceeded,
# a cleanup operation is triggered to reduce memory usage.
CLEANUP_THRESHOLD = 50
def __init__(self):
self.free_table: Dict[int, BlockMetaData] = {}
self.priority_queue = []
def __contains__(self, block_id: int) -> bool:
return block_id in self.free_table
def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
while self.priority_queue:
# We do not remove outdated entries from the priority queue at the
# time of updating the last_accessed timestamp. Instead, outdated
# entries are filtered out here during eviction. Outdated entries
# would either not in the free table, or have older last accessed
# time.
last_accessed, _, block_id, content_hash = heapq.heappop(
self.priority_queue)
if (block_id in self.free_table and
self.free_table[block_id].last_accessed == last_accessed):
self.free_table.pop(block_id)
return block_id, content_hash
raise ValueError("No usable cache memory left")
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens,
last_accessed)
heapq.heappush(
self.priority_queue,
(last_accessed, -num_hashed_tokens, block_id, content_hash))
self._cleanup_if_necessary()
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
def _cleanup_if_necessary(self):
if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len(
self.free_table):
self._cleanup()
def _cleanup(self):
new_priority_queue: List[Tuple[float, int, int, int]] = []
for block_id, block in self.free_table.items():
new_priority_queue.append(
(block.last_accessed, -block.num_hashed_tokens, block_id,
block.content_hash))
heapq.heapify(new_priority_queue)
self.priority_queue = new_priority_queue
def remove(self, block_id: int):
if block_id not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
self.free_table.pop(block_id)
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated.
The capacity of allocator is larger than seq_group required.
3. Never: seq_group can never be allocated.
The seq_group is too large to allocated in GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager(ABC):
@staticmethod
def get_block_space_manager_class(version: str):
version = version.lower()
if version == "selfattn":
from vllm.core.block_manager import SelfAttnBlockSpaceManager
return SelfAttnBlockSpaceManager
if version == "placeholder":
from vllm.core.placeholder_block_space_manager import (
PlaceholderBlockSpaceManager)
return PlaceholderBlockSpaceManager
raise ValueError(f"Unknown version {version=}")
@abstractmethod
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
pass
@abstractmethod
def allocate(self, seq_group: SequenceGroup) -> None:
pass
@abstractmethod
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
pass
@abstractmethod
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
pass
@abstractmethod
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
pass
@abstractmethod
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
pass
@abstractmethod
def free(self, seq: Sequence) -> None:
pass
@abstractmethod
def get_block_table(self, seq: Sequence) -> List[int]:
pass
@abstractmethod
def get_num_free_gpu_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_cpu_blocks(self) -> int:
pass
@abstractmethod
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for specified or all devices."""
pass
@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
pass
@abstractmethod
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
pass
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class PlaceholderBlockSpaceManager(BlockSpaceManager):
"""A version of BlockSpaceManager for use in environments
where block management is not required.
For example: pooling models or attention-free models like Mamba.
This class provides the same interface as BlockSpaceManager, but its
methods perform no actions or return simple values like True in specific
actions. It's designed to be used in scenarios where the overhead of
block management is unnecessary, such as in an embedding environment.
"""
def __init__(
self,
**kwargs,
) -> None:
pass
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK
def allocate(self, seq_group: SequenceGroup) -> None:
# No actual allocation logic needed
pass
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return True
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
return []
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.OK
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
return None # type: ignore
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return True
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
return None # type: ignore
def free(self, seq: Sequence) -> None:
# No operation on free
return
def get_block_table(self, seq: Sequence) -> List[int]:
return None # type: ignore
def get_num_free_gpu_blocks(self) -> int:
return 1
def get_num_free_cpu_blocks(self) -> int:
return 1
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
def get_common_computed_block_ids(self,
seq_group: List[Sequence]) -> List[int]:
return []
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return True
def get_num_cached_tokens(self, seq: Sequence) -> int:
return 0
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
return
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
from vllm.config import CacheConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage,
SequenceStatus)
from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
# ARTIFICIAL_PREEMPTION_PROB% probability.
ENABLE_ARTIFICIAL_PREEMPT = bool(
os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500
class PreemptionMode(enum.Enum):
"""Preemption modes.
1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
and swap them back in when the sequences are resumed.
2. Recomputation: Discard the blocks of the preempted sequences and
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()
@dataclass
class SchedulingBudget:
"""The available slots for scheduling.
TODO(sang): Right now, the budget is request_id-aware meaning it can ignore
budget update from the same request_id. It is because in normal scheduling
path, we update RUNNING num_seqs ahead of time, meaning it could be
updated more than once when scheduling RUNNING requests. Since this won't
happen if we only have chunked prefill scheduling, we can remove this
feature from the API when chunked prefill is enabled by default.
"""
token_budget: int
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
# Number of cached tokens in the batch.
_num_cached_tokens: int = 0
# Number of actual non-cached tokens in the batch.
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
# We allow num_new_tokens to be 0 when the entire sequence has
# been cached.
assert num_new_tokens >= 0
assert num_new_seqs != 0
return (self.num_batched_tokens + num_new_tokens <= self.token_budget
and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self,
req_id: str,
num_batched_tokens: int,
num_cached_tokens: int = 0):
if req_id in self._request_ids_num_batched_tokens:
return
assert num_cached_tokens >= 0
assert num_batched_tokens >= 0
self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens
self._num_cached_tokens += num_cached_tokens
def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int):
if req_id in self._request_ids_num_batched_tokens:
self._request_ids_num_batched_tokens.remove(req_id)
self._num_batched_tokens -= num_batched_tokens
def add_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._request_ids_num_curr_seqs:
return
self._request_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs
def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._request_ids_num_curr_seqs:
self._request_ids_num_curr_seqs.remove(req_id)
self._num_curr_seqs -= num_curr_seqs
@property
def num_batched_tokens(self):
return self._num_batched_tokens
@property
def num_curr_seqs(self):
return self._num_curr_seqs
@property
def num_cached_tokens(self):
return self._num_cached_tokens
@dataclass
class ScheduledSequenceGroup:
# A sequence group that's scheduled.
seq_group: SequenceGroup
# The total chunk size (number of tokens) to process for next iteration.
# 1 for decoding. Same as prompt tokens for prefill, but if prefill is
# chunked, it can be smaller than that.
token_chunk_size: int
@dataclass
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
num_prefill_groups: int
# Total number of batched tokens.
num_batched_tokens: int
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]]
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]]
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]]
# Sequence groups that are going to be ignored.
ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# The number of requests in the running queue
running_queue_size: int
preempted: int
def __post_init__(self):
# Swap in and swap out should never happen at the same time.
assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
self.num_loras: int = len(self.lora_requests)
if self.num_loras > 0:
self._sort_by_lora_ids()
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self):
assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
def key_fn(group: ScheduledSequenceGroup):
key = (group.seq_group.lora_int_id, group.seq_group.request_id)
if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
# Sort sequence groups so that all prefills come before all
# decodes as required by chunked prefill.
return (not group.seq_group.is_prefill(), *key)
return key
self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
key=key_fn)
@property
def lora_requests(self) -> Set[LoRARequest]:
return {
g.seq_group.lora_request
for g in self.scheduled_seq_groups
if g.seq_group.lora_request is not None
}
@dataclass
class SchedulerRunningOutputs:
"""The requests that are scheduled from a running queue.
Could contain prefill (prefill that's chunked) or decodes. If there's not
enough memory, it can be preempted (for recompute) or swapped out.
"""
# Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase.
# I.e., it means the prefill has been chunked.
prefill_seq_groups: List[ScheduledSequenceGroup]
# The preempted sequences.
preempted: List[SequenceGroup]
# Sequences that are swapped out.
swapped_out: List[SequenceGroup]
# The blocks to swap out.
blocks_to_swap_out: List[Tuple[int, int]]
# The blocks to copy.
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# Optimization for fast-access to seq_group lists
decode_seq_groups_list: List[SequenceGroup]
prefill_seq_groups_list: List[SequenceGroup]
@classmethod
def create_empty(cls) -> "SchedulerRunningOutputs":
return SchedulerRunningOutputs(
decode_seq_groups=[],
prefill_seq_groups=[],
preempted=[],
swapped_out=[],
blocks_to_swap_out=[],
blocks_to_copy=[],
num_lookahead_slots=0,
decode_seq_groups_list=[],
prefill_seq_groups_list=[],
)
@dataclass
class SchedulerSwappedInOutputs:
"""The requests that are scheduled from a swap queue.
Could contain prefill (prefill that's chunked) or decodes.
"""
# Selected sequences that are going to be swapped in and is in a
# decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are going to be swapped in and in a prefill
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups: List[ScheduledSequenceGroup]
# The blocks to swap in.
blocks_to_swap_in: List[Tuple[int, int]]
# The blocks to copy.
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# Infeasible sequence groups.
infeasible_seq_groups: List[SequenceGroup]
@classmethod
def create_empty(cls) -> "SchedulerSwappedInOutputs":
return SchedulerSwappedInOutputs(
decode_seq_groups=[],
prefill_seq_groups=[],
blocks_to_swap_in=[],
blocks_to_copy=[],
num_lookahead_slots=0,
infeasible_seq_groups=[],
)
@dataclass
class SchedulerPrefillOutputs:
"""The requests that are scheduled from a waiting queue.
Could contain a fresh prefill requests or preempted requests that need
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
@classmethod
def create_empty(cls) -> "SchedulerPrefillOutputs":
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
)
def seq_group_metadata_builder():
return SequenceGroupMetadata(request_id="",
is_prompt=False,
seq_data={},
sampling_params=None,
block_tables={})
def scheduler_running_outputs_builder():
return SchedulerRunningOutputs(decode_seq_groups=[],
prefill_seq_groups=[],
preempted=[],
swapped_out=[],
blocks_to_swap_out=[],
blocks_to_copy=[],
num_lookahead_slots=0,
prefill_seq_groups_list=[],
decode_seq_groups_list=[])
def scheduled_seq_group_builder():
return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
token_chunk_size=0)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
@dataclass
class PartialPrefillMetadata:
"""Holds information about the partial prefills that are currently running
during a single iteration of the Scheduler.
When chunked prefill is enabled, we allow a certain number of seqs to be
partially prefilled during each iteration. Having multiple partial prefills
in flight allows us to minimize TTFT and avoid decode starvation in cases
where a single sequence group with a very large prompt blocks the queue for
too many iterations.
The number of long prefill requests is limited so that smaller
requests may jump the queue in front of them and get to the decode
phase faster.
"""
# A minimum bound on the total number of prefills to be scheduled during
# this iteration
schedulable_prefills: int
# The number of long prefill requests currently running
long_prefills: int
scheduler_config: SchedulerConfig
def can_schedule(self, seq_group: SequenceGroup) -> bool:
"""When concurrent partial prefills are enabled,
we limit the number of long requests and only accept
shorter requests from the queue while running them
concurrently"""
return not (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold
and self.long_prefills
>= self.scheduler_config.max_long_partial_prefills
and self.scheduler_config.max_num_partial_prefills > 1)
def maybe_increment_partial_prefills(self,
seq_group: SequenceGroup) -> None:
# When a new prefill is scheduled, we need to know if it is a
# long request
if (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold):
self.long_prefills += 1
@classmethod
def from_queues(
cls,
running: Deque[SequenceGroup],
waiting: Deque[SequenceGroup],
scheduler_config: SchedulerConfig,
) -> "PartialPrefillMetadata":
"""Create a PartialPrefillMetadata object from the current state of
the scheduler's queues.
This accounts for the currently running prefill requests, and peeks into
the waiting queue to see if there are more prefills to potentially be
scheduled during this iteration."""
prefills = 0
long_prefills = 0
waiting_long_prefills = 0
for sg in running:
if sg.first_seq.data.stage == SequenceStage.PREFILL:
prefills += 1
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
long_prefills += 1
for sg in waiting:
# Don't bother looping through the rest of the queue if we know
# there are already at
# least max_partial_prefills requests to fill
if prefills >= scheduler_config.max_num_partial_prefills:
break
# Don't count long requests from the waiting queue if we aren't
# going to schedule them anyway
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
if (long_prefills + waiting_long_prefills
>= scheduler_config.max_long_partial_prefills):
continue
waiting_long_prefills += 1
prefills += 1
# NB: long_prefills and waiting_long_prefills are tracked separately.
# We don't account for the waiting requests here because we need to use
# this metadata to track how many have actually been scheduled.
return PartialPrefillMetadata(
schedulable_prefills=min(
prefills, scheduler_config.max_num_partial_prefills),
long_prefills=long_prefills,
scheduler_config=scheduler_config,
)
class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
version = "selfattn"
if (self.scheduler_config.runner_type == "pooling"
or self.cache_config.is_attention_free):
version = "placeholder"
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)
num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size
num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching,
)
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state.
# Contain decode requests.
self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step
self.prev_time = 0.0
# Did we schedule a prompt at previous step?
self.prev_prompt = False
# Latency of the last prompt step
self.last_prompt_latency = 0.0
# preemption mode, RECOMPUTE or SWAP
self.user_specified_preemption_mode = scheduler_config.preemption_mode
# The following field is test-only. It is used to inject artificial
# preemption.
self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
if self.enable_artificial_preemption
else 0)
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._seq_group_metadata_cache: List[PyObjectCache] = []
self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: List[PyObjectCache] = []
# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self.output_proc_callback = output_proc_callback
self.use_async_output_proc = self.output_proc_callback is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1
self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
# List with the chunk sizes to hand out to each sequence depending
# on how many partial prefills are running. This is slightly faster than
# running an integer division every time a prefill is scheduled.
# This splits the budget evenly among all prefills.
self.partial_prefill_budget_lookup_list = [0] * (
self.scheduler_config.max_num_partial_prefills + 1)
self.partial_prefill_budget_lookup_list[0] = (
scheduler_config.max_num_batched_tokens)
for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
self.partial_prefill_budget_lookup_list[i] = (
scheduler_config.max_num_batched_tokens // i)
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
@property
def num_decoding_tokens_per_seq(self) -> int:
"""The number of new tokens."""
return 1
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the running queue.
# Only for testing purposes.
self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self.swapped.append(seq_group)
def abort_seq_group(
self,
request_id: Union[str, Iterable[str]],
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
) -> None:
"""Aborts a sequence group with the given ID.
Check if the sequence group with the given ID
is present in any of the state queue.
If present, remove the sequence group from the state queue.
Also, if any of the sequences in the sequence group is not finished,
free the sequence with status `FINISHED_ABORTED`.
Otherwise, do nothing.
Args:
request_id: The ID(s) of the sequence group to abort.
seq_id_to_seq_group: helper for groups with n>1
"""
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
seq_id_to_seq_group = seq_id_to_seq_group or {}
for state_queue in [self.waiting, self.running, self.swapped]:
aborted_groups: List[SequenceGroup] = []
for seq_group in state_queue:
# When n>1, seq_group.request_id looks like
# foo_parallel_sample_0, while request_ids is just foo, and we
# should resolve it as real_request_id to match.
if seq_group.request_id in seq_id_to_seq_group:
real_request_id = seq_id_to_seq_group[
seq_group.request_id].group_id
else:
real_request_id = seq_group.request_id
if real_request_id in request_ids:
# Appending aborted group into pending list.
aborted_groups.append(seq_group)
# We can't remove real_request_id in request_ids here,
# because there may be other seq groups sharing the same
# real_request_id
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
if aborted_group.request_id in seq_id_to_seq_group:
del seq_id_to_seq_group[aborted_group.request_id]
self._free_seq_group_cross_attn_blocks(aborted_group)
def _free_seq_group_cross_attn_blocks(
self,
seq_group: SequenceGroup,
) -> None:
"""
Free a sequence group from a cross-attention block table.
Has no effect on decoder-only models.
"""
if seq_group.is_encoder_decoder():
self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool:
return (len(self.waiting) != 0 or len(self.running) != 0
or len(self.swapped) != 0)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_manager.reset_prefix_cache(device)
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self._finished_requests_ids
self._finished_requests_ids = list()
return finished_requests_ids
def _schedule_running(
self,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
ret.prefill_seq_groups.clear()
ret.preempted.clear()
ret.swapped_out.clear()
ret.num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill=False, enable_chunking=enable_chunking)
ret.decode_seq_groups_list.clear()
ret.prefill_seq_groups_list.clear()
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
prefill_seq_groups: List[
ScheduledSequenceGroup] = ret.prefill_seq_groups
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
# We discard the cached tokens info here because we don't need it
# for running sequence:
# 1. If a sequence is running with chunked prefill, the cached
# tokens info was already used for the first prefill.
# 2. If a sequence is running with non-chunked prefill, then
# there it's a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group,
SequenceStatus.RUNNING,
enable_chunking,
budget,
partial_prefill_metadata,
)
num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0:
# No budget => Stop
break
running_queue.popleft()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if (self.use_async_output_proc and seq_group.seqs[0].get_len()
> self.scheduler_config.max_model_len):
self._async_stopped.append(seq_group)
continue
# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while not self._can_append_slots(seq_group, enable_chunking):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
num_running_seqs = seq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if (curr_loras is not None and seq_group.lora_int_id > 0
and seq_group.lora_int_id in curr_loras):
curr_loras.remove(seq_group.lora_int_id)
# Determine victim sequence
cont_loop = True
if running_queue:
# Preempt the lowest-priority sequence group.
victim_seq_group = running_queue.pop()
else:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group = seq_group
cont_loop = False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt = True
if self.use_async_output_proc:
assert self.output_proc_callback is not None
self.output_proc_callback(
request_id=victim_seq_group.request_id)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if victim_seq_group.is_finished():
self._free_finished_seq_group(victim_seq_group)
do_preempt = False
# Do preemption
if do_preempt:
preempted_mode = self._preempt(victim_seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(victim_seq_group)
else:
swapped_out.append(victim_seq_group)
if not cont_loop:
break
else:
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = (
self._scheduled_seq_group_cache[
self.cache_id].get_object())
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
prefill_seq_groups.append(scheduled_seq_group)
ret.prefill_seq_groups_list.append(seq_group)
else:
scheduled_seq_group.token_chunk_size = 1
decode_seq_groups.append(scheduled_seq_group)
ret.decode_seq_groups_list.append(seq_group)
budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens)
# OPTIMIZATION: Note that get_max_num_running_seqs is
# expensive. For the default scheduling chase where
# enable_chunking is False, num_seqs are updated before running
# this method, so we don't have to update it again here.
if enable_chunking:
num_running_seqs = seq_group.get_max_num_running_seqs()
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
self._scheduler_running_outputs_cache[self.next_cache_id].reset()
self._scheduled_seq_group_cache[self.next_cache_id].reset()
return ret
def _schedule_swapped(
self,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
curr_loras <= max_lora from the scheduling config. The input arguments
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: List[Tuple[int, int]] = []
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
# If the sequence group cannot be swapped in, stop.
is_prefill = seq_group.is_prefill()
alloc_status = self.block_manager.can_swap_in(
seq_group,
self._get_num_lookahead_slots(is_prefill, enable_chunking))
if alloc_status == AllocStatus.LATER:
break
elif alloc_status == AllocStatus.NEVER:
logger.warning(
"Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence.",
seq_group.request_id,
)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
infeasible_seq_groups.append(seq_group)
swapped_queue.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)
swapped_queue.popleft()
continue
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.SWAPPED, enable_chunking,
budget))
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.SWAPPED)
break
if lora_int_id > 0 and curr_loras is not None:
curr_loras.add(lora_int_id)
swapped_queue.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
if is_prefill:
prefill_seq_groups.append(
ScheduledSequenceGroup(
seq_group,
token_chunk_size=num_new_tokens_uncached +
num_new_tokens_cached,
))
else:
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
swapped_queue.extendleft(leftover_swapped)
return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False, enable_chunking=enable_chunking),
infeasible_seq_groups=infeasible_seq_groups,
)
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled:
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens,
)
# Model is fine tuned with long context. Return the fine tuned max_len.
if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
assert prompt_limit <= seq_group.lora_request.long_lora_max_len
return seq_group.lora_request.long_lora_max_len
else:
return prompt_limit
def _get_priority(self,
seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
"""Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time.
Args:
seq_group: The sequence group input.
Returns:
The priority of the sequence group.
"""
return seq_group.priority, seq_group.arrival_time
def _schedule_priority_preemption(
self,
budget: SchedulingBudget,
) -> int:
"""Sorts waiting and running queue. Also, force preempt requests
from the running queue if their priority is lower.
Priority-based preemption is used with the priority policy.
Args:
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
Returns:
A count of priority-based preemptions.
"""
waiting_queue = self.waiting
running_queue = deque(sorted(self.running, key=self._get_priority))
blocks_to_swap_out: List[Tuple[int, int]] = []
force_preemption_count = 0
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens_uncached, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget)
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK
and budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
)):
break
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget))
budget.subtract_num_batched_tokens(
vseq_group.request_id, num_running_tokens_uncached)
num_running_seqs = vseq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
# Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
# Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
self.waiting = waiting_queue
self.running = running_queue
return force_preemption_count
def _schedule_prefills(
self,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
as a new prefill (that starts from beginning -> most recently generated
tokens).
It schedules waiting requests as long as it fits `budget` and
curr_loras <= max_lora from the scheduling config. The input arguments
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are scheduled.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerPrefillOutputs.
"""
if budget.remaining_token_budget() == 0:
# Do nothing: Can't add any more prefill anyway
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
using_prompt_embeds: bool = False
waiting_queue = self.waiting
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
if (partial_prefill_metadata is not None
and not partial_prefill_metadata.can_schedule(seq_group)):
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group,
SequenceStatus.WAITING,
enable_chunking,
budget,
partial_prefill_metadata=partial_prefill_metadata,
))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking:
num_prompt_tokens = waiting_seqs[0].get_len()
assert num_new_tokens == num_prompt_tokens
prompt_limit = self._get_prompt_limit(seq_group)
if num_new_tokens > prompt_limit:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d",
num_new_tokens,
prompt_limit,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.FINISHED_IGNORED)
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
num_lookahead_slots: int = 0
# If the sequence group cannot be allocated, stop.
can_allocate = self.block_manager.can_allocate(
seq_group, num_lookahead_slots=num_lookahead_slots)
if can_allocate == AllocStatus.LATER:
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
"Input prompt (%d tokens) + lookahead slots (%d) is "
"too long and exceeds the capacity of block_manager",
num_new_tokens,
num_lookahead_slots,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.FINISHED_IGNORED)
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
# We cannot mix sequence groups that use prompt embeds and
# those that do not.
if len(seq_groups) == 0:
using_prompt_embeds = seq_group.uses_prompt_embeds()
if using_prompt_embeds != seq_group.uses_prompt_embeds():
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
if (budget.num_batched_tokens
>= self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
num_new_seqs = seq_group.get_max_num_running_seqs()
if num_new_tokens_uncached == 0 or not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
# Can schedule this request.
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
if partial_prefill_metadata is not None:
partial_prefill_metadata.maybe_increment_partial_prefills(
seq_group)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
budget.add_num_batched_tokens(
seq_group.request_id,
num_batched_tokens=num_new_tokens_uncached,
num_cached_tokens=num_new_tokens_cached,
)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
if len(seq_groups) > 0:
self.prev_prompt = True
return SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
)
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can
be swapped or preempted.
"""
# Include running requests to the budget.
budget = SchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
# Make sure we include num running seqs before scheduling prefill,
# so that we don't schedule beyond max_num_seqs for prefill.
for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None)
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# If any requests are swapped, prioritized swapped requests.
if not self.swapped:
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=False)
if len(prefills.seq_groups
) == 0 and self.scheduler_config.policy == "priority":
self._schedule_priority_preemption(budget)
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out) == 0):
swapped_in = \
self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(running_scheduled.decode_seq_groups_list)
if len(swapped_in.decode_seq_groups) > 0:
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
preempted = len(running_scheduled.preempted) + len(
running_scheduled.swapped_out)
# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
assert len(running_scheduled.prefill_seq_groups) == 0
assert len(swapped_in.prefill_seq_groups) == 0
# Merge lists
num_prefill_groups = len(prefills.seq_groups)
ignored_seq_groups_for_embeds = list[SequenceGroup]()
if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
ignored_seq_groups_for_embeds.clear()
else:
scheduled_seq_groups = running_scheduled.decode_seq_groups
if len(scheduled_seq_groups) > 0:
using_prompt_embeds = scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
ignored_seq_groups_for_embeds.clear()
indices_ignored = list[int]()
for i, schedule_seq_group in enumerate(scheduled_seq_groups):
if using_prompt_embeds !=\
schedule_seq_group.seq_group.uses_prompt_embeds():
ignored_seq_groups_for_embeds.append(
schedule_seq_group.seq_group)
indices_ignored.append(i)
if len(ignored_seq_groups_for_embeds) > 0:
scheduled_seq_groups = [
group for i, group in enumerate(scheduled_seq_groups)
if i not in indices_ignored
]
else:
ignored_seq_groups_for_embeds.clear()
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs(
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=preempted,
)
def _schedule_chunked_prefill(self) -> SchedulerOutputs:
"""Schedule queued requests.
Chunked prefill allows to chunk prefill requests, batch them together
with decode requests. This policy 1. schedule as many decoding requests
as possible. 2. schedule chunked prefill requests that are not
finished. 3. schedule swapped request. 4. schedule new prefill
requests.
The policy can sustain the high GPU utilization because it can put
prefill and decodes requests to the same batch, while it improves
inter token latency because decodes requests don't need to be blocked
by prefill requests.
"""
budget = SchedulingBudget(
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras: Set[int] = set()
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Create partial prefill metadata
partial_prefill_metadata = PartialPrefillMetadata.from_queues(
running=self.running,
waiting=self.waiting,
scheduler_config=self.scheduler_config,
)
# Decoding should be always scheduled first by fcfs.
running_scheduled = self._schedule_running(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
prefills = self._schedule_prefills(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
# Because multiple prefills may be running concurrently, we need to
# make sure that prefills which are scheduled to finish are listed
# before those that won't. This is so that on the next scheduling
# iteration when they have transitioned to the decode stage, they are
# properly prioritized over sequences that are still in the prefill
# stage.
self.running.extend(
self._order_finishing_prefills_first(
running_scheduled.prefill_seq_groups))
self.running.extend([s.seq_group for s in prefills.seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
# Put prefills first due to Attention backend ordering assumption.
scheduled_seq_groups = (prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups)
num_prefill_groups = (len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups))
return SchedulerOutputs(
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=0,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
)
def _order_finishing_prefills_first(
self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
) -> List[SequenceGroup]:
"""Returns a list of prefilling SequenceGroups where sequences that are
scheduled to finish prefilling are listed first"""
finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
]
not_finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
]
return finishing + not_finishing
def _schedule(self) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
return self._schedule_chunked_prefill()
else:
return self._schedule_default()
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
# It is True only for testing case to trigger artificial preemption.
if (self.enable_artificial_preemption
and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
and self.artificial_preempt_cnt > 0):
self.artificial_preempt_cnt -= 1
return False
is_prefill = seq_group.is_prefill()
num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill, enable_chunking)
return self.block_manager.can_append_slots(
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
# async_output_proc is allowed only when we have a single sequence
# in the sequence group
no_single_seq = seq_group.sampling_params is None or (
seq_group.sampling_params.n == 1)
return no_single_seq
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs: SchedulerOutputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
allow_async_output_proc: bool = self.use_async_output_proc
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache[
self.cache_id].get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
encoder_seq_data = encoder_seq.data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table = self.block_manager.get_cross_block_table(
seq_group)
else:
encoder_seq_data = None
cross_block_table = None
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
if self.cache_config.enable_prefix_caching:
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True
is_prompt = seq_group.is_prefill()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
is_first_prefill = num_computed_tokens == 0
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + num_computed_tokens
< seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=(seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups
> 0 else None),
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
)
else:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta = {}
for id, data in seq_data.items():
seq_data_delta[id] = data.get_delta_and_reset()
seq_group_metadata = SequenceGroupMetadataDelta(
seq_data_delta,
seq_group.request_id,
block_tables,
is_prompt,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
computed_block_nums=common_computed_block_nums,
)
seq_group_metadata_list.append(seq_group_metadata)
if allow_async_output_proc:
allow_async_output_proc = self._allow_async_output_proc(
seq_group)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)
self._seq_group_metadata_cache[self.next_cache_id].reset()
scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
# component in the e2e latency.
for seq_group in self.running:
if seq_group is not None and seq_group.metrics is not None:
if seq_group.metrics.scheduler_time is not None:
seq_group.metrics.scheduler_time += scheduler_time
else:
seq_group.metrics.scheduler_time = scheduler_time
# Move to next cache (if exists)
self.cache_id = self.next_cache_id
# Return results
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)
def free_seq(self, seq: Sequence) -> None:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def remove_seq_from_computed_blocks_tracker(
self, seq_group: SequenceGroup,
status: Optional[SequenceStatus]) -> None:
seqs = seq_group.get_seqs(status=status)
for seq in seqs:
self._remove_seq_from_computed_blocks_tracker(seq)
def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
"""
Free a sequence computed blocks tracker _seq_id_to_blocks_hashes
and _seq_id_to_num_tokens_computed.
"""
self.block_manager.remove_seq_from_computed_blocks_tracker(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
"""Free finished seqs in a sequence group."""
for seq in seq_group.get_seqs():
if seq.is_finished():
self.free_seq(seq)
def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
self._free_finished_seq_group(seq_group)
if not seq_group.is_finished():
remaining.append(seq_group)
self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False,
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
ints, the first int is the source block index, and the second
int is the destination block index. This list is updated with
the new source and destination block indices for the appended
slots.
enable_chunking (bool): True if chunked prefill is enabled.
"""
is_prefill: bool = seq_group.is_prefill()
num_lookahead_slots: int = self._get_num_lookahead_slots(
is_prefill, enable_chunking)
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
for seq in seq_group.get_seqs(status=seq_status):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
if len(cows) > 0:
blocks_to_copy.extend(cows)
def _preempt(self, seq_group: SequenceGroup,
blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not currently supported. In
# such a case, we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized
# over sequence groups with a single sequence.
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if self.user_specified_preemption_mode is None:
if seq_group.get_max_num_running_seqs() == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
elif self.user_specified_preemption_mode == "swap":
preemption_mode = PreemptionMode.SWAP
else:
preemption_mode = PreemptionMode.RECOMPUTE
if self.num_cumulative_preemption % 50 == 0:
logger.warning(
"Sequence group %s is preempted by %s mode because there is "
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d",
seq_group.request_id,
preemption_mode,
self.num_cumulative_preemption + 1,
)
self.num_cumulative_preemption += 1
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out)
else:
raise AssertionError("Invalid preemption mode.")
return preemption_mode
def _preempt_by_recompute(
self,
seq_group: SequenceGroup,
) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
assert len(seqs) == 1
for seq in seqs:
seq.status = SequenceStatus.WAITING
self.free_seq(seq)
seq.reset_state_for_recompute()
self._free_seq_group_cross_attn_blocks(seq_group)
def _preempt_by_swap(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: List[Tuple[int, int]],
) -> None:
self._swap_out(seq_group, blocks_to_swap_out)
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: List[Tuple[int, int]],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.extend(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: List[Tuple[int, int]],
) -> None:
if not self.block_manager.can_swap_out(seq_group):
# FIXME(woosuk): Abort the sequence group instead of aborting the
# entire engine.
raise RuntimeError(
"Aborted due to the lack of CPU swap space. Please increase "
"the swap space to avoid this error.")
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.extend(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
def _passed_delay(self, now: float) -> bool:
if self.prev_prompt:
self.last_prompt_latency = now - self.prev_time
self.prev_time, self.prev_prompt = now, False
# Delay scheduling prompts to let waiting queue fill up
if self.scheduler_config.delay_factor > 0 and self.waiting:
earliest_arrival_time = min(
[e.metrics.arrival_time for e in self.waiting])
passed_delay = ((now - earliest_arrival_time)
> (self.scheduler_config.delay_factor *
self.last_prompt_latency) or not self.running)
else:
passed_delay = True
return passed_delay
def _get_num_lookahead_slots(self, is_prefill: bool,
enable_chunking: bool) -> int:
"""The number of slots to allocate per sequence per step, beyond known
token ids. Speculative decoding uses these slots to store KV activations
of tokens which may or may not be accepted.
"""
return 0
def _get_num_new_uncached_and_cached_tokens(
self,
seq_group: SequenceGroup,
status: SequenceStatus,
enable_chunking: bool,
budget: SchedulingBudget,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> Tuple[int, int]:
"""
Returns the number of new uncached and cached tokens to schedule for a
given sequence group that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
Returns (0, 0) if the new token cannot be computed due to token budget.
The cached tokens's blocks are already computed, and the attention
backend will reuse the cached blocks rather than recomputing them. So
the scheduler could schedule these cached tokens "for free".
Args:
seq_group: The sequence group to get the number of new tokens to
schedule.
status: The status of the sequences to get the number of new tokens
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
A tuple of two ints. The first int is the number of new uncached
tokens to schedule. The second int is the number of cached tokens.
If no more new tokens can be scheduled, returns (0, 0).
"""
num_cached_new_tokens = 0
num_uncached_new_tokens = 0
seqs = seq_group.get_seqs(status=status)
# Compute the number of new uncached and cached tokens for
# each sequence.
for seq in seqs:
if not seq.is_prefill():
# Decode sequences should always just have 1 uncached token
# TODO(rickyx): Actually is this still correct for multi-step?
num_uncached_new_tokens += 1
continue
num_computed_tokens_seq = seq.get_num_computed_tokens()
all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
if not self.cache_config.enable_prefix_caching:
# If prefix caching is not enabled, all new tokens are uncached.
num_uncached_new_tokens += all_num_new_tokens_seq
continue
# NOTE: the cache token might be currently in a block that's in an
# evictor meaning that it's not yet allocated. However, we don't
# exclude such tokens in the cache count because it will be
# guaranteed to be allocated later if the sequence can be allocated.
num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
seq)
# Sanity check.
if num_cached_tokens_seq < num_computed_tokens_seq:
# This should only happen with chunked prefill, and
# the seq is still in prefill. The `num_cached_tokens_seq`
# is the value we calculated on scheduling the first prefill.
# For subsequent continuous prefill steps, we cached the
# number of cache tokens for the sequence so the cached token
# count could be less than the number of computed tokens.
# See comments on `ComputedBlocksTracker` for more details.
assert (
seq.is_prefill() and seq.status == SequenceStatus.RUNNING
and self.scheduler_config.chunked_prefill_enabled
), ("Number of cached tokens should not be less than the "
"number of computed tokens for a sequence that's still "
f"in prefill. But there are {num_cached_tokens_seq} cached "
f"tokens and {num_computed_tokens_seq} computed tokens "
f"for sequence {seq.seq_id}.")
num_cached_new_tokens_seq = max(
0, num_cached_tokens_seq - num_computed_tokens_seq)
num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
num_cached_new_tokens_seq)
num_uncached_new_tokens += num_uncached_new_tokens_seq
num_cached_new_tokens += num_cached_new_tokens_seq
if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
# For a fully cached hit sequence, we actually need to recompute the
# last token. So we need at least 1 uncached token to schedule.
# See ModelRunner._compute_for_prefix_cache_hit for more details.
num_uncached_new_tokens = 1
num_cached_new_tokens -= 1
if enable_chunking and len(seqs) == 1:
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
self.scheduler_config,
self.cache_config,
budget,
self._get_prompt_limit(seq_group),
num_uncached_new_tokens,
self.partial_prefill_budget_lookup_list,
partial_prefill_metadata,
)
return num_uncached_new_tokens, num_cached_new_tokens
@staticmethod
def _chunk_new_tokens_to_schedule(
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
budget: SchedulingBudget,
prompt_limit: int,
num_new_tokens: int,
partial_prefill_budget_lookup_list: List[int],
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> int:
"""
Chunks the number of new tokens to schedule based on the budget when
chunked prefill is enabled.
Args:
scheduler_config: The scheduler config.
cache_config: The cache config.
budget: The budget to chunk the number of tokens to compute.
prompt_limit: The maximum number of tokens allowed in a prompt.
num_new_tokens: The number of new tokens to schedule.
Returns:
The number of new tokens to schedule after chunking.
"""
remaining_token_budget = budget.remaining_token_budget()
# Get the number of tokens to allocate to this prefill slot
prefill_slot_budget = (
remaining_token_budget if partial_prefill_metadata is None else
partial_prefill_budget_lookup_list[
partial_prefill_metadata.schedulable_prefills])
if cache_config.enable_prefix_caching:
# When prefix caching is enabled and we're partially prefilling
# a sequence, we always allocate a number of new tokens that is
# divisible by the block size to avoid partial block matching.
block_size = cache_config.block_size
# Don't exceed either the total budget or slot budget.
# Take min of those and get the next lowest multiple of the
# block size:
remaining_token_budget = (
min(remaining_token_budget, prefill_slot_budget) //
block_size) * block_size
# NB: In the case where num_new_tokens < budget, we are
# finishing prefill for this sequence, so we do not need to
# allocate a full block.
num_new_tokens = min(num_new_tokens, remaining_token_budget,
prefill_slot_budget)
return num_new_tokens
......@@ -7,13 +7,11 @@ from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
......@@ -266,11 +264,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[list[SamplerOutput]] = None,
) -> None:
async def do_log_stats(self) -> None:
...
@abstractmethod
......
......@@ -601,11 +601,7 @@ class AsyncLLM(EngineClient):
async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None
async def do_log_stats(
self,
scheduler_outputs=None,
model_output=None,
) -> None:
async def do_log_stats(self) -> None:
if self.logger_manager:
self.logger_manager.log()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CacheEngine class for managing the KV cache."""
from typing import List
import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
get_dtype_size, is_pin_memory_available)
logger = init_logger(__name__)
class CacheEngine:
"""Manages the KV cache.
This class is responsible for initializing and managing the GPU and CPU KV
caches. It also provides methods for performing KV cache operations, such
as swapping and copying.
"""
def __init__(
self,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
) -> None:
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.device_config = device_config
self.head_size = model_config.get_head_size()
# Models like Jamba, have mixed typed layers, E.g Mamba
self.num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(self.head_size,
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
model_config.is_attention_free,
use_mla=model_config.use_mla)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(
self.num_gpu_blocks, self.device_config.device_type)
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
# The allocation respects the backend-defined stride order to ensure
# the semantic remains consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order which could result in a non-contiguous tensor.
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
for i in kv_cache_stride_order)
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(
kv_cache_allocation_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device).permute(*kv_cache_stride_order)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append(layer_kv_cache)
return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
def swap_out(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
key_cache_entry = num_heads * head_size
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
value_cache_entry = key_cache_entry if not model_config.use_mla else 0
total = num_attention_layers * cache_config.block_size * \
(key_cache_entry + value_cache_entry)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import gc
import inspect
import itertools
import time
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union)
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from tqdm.auto import tqdm
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import (supports_lora, supports_mrope,
supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo,
weak_ref_tensor)
from vllm.worker.model_runner_base import (
InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
_NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
# For now, bump up cache limits for recompilations during CUDA graph warmups.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
"""
This base class contains metadata needed for the base model forward pass
but not metadata for possible additional steps, e.g., sampling. Model
runners that run additional steps should subclass this method to add
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
async_callback: Optional[Callable] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
previous_hidden_states: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForGPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForGPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
# Exclude `async_callback` to be able to pickle this object
def __getstate__(self):
state = self.__dict__.copy()
del state["async_callback"]
return state
# TODO: What happens when we depickle this object?
# How can we update this callback to properly pass it to the engine?
def __setstate__(self, state):
self.__dict__.update(state)
self.__dict__.update({'async_callback': None})
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForGPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
# Note: ideally we would be using a dataclass(kw_only=True)
# here, so that this can be subclassed easily,
# but kw_only is not supported in python<3.10.
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
self.prompt_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore
self.lora_index_mapping.clear() # type: ignore
self.lora_prompt_mapping.clear() # type: ignore
self.lora_requests.clear() # type: ignore
def __init__(
self,
*,
# From sequence group metadata.
request_id: str,
seq_ids: List[int],
is_prompt: bool,
block_tables: Optional[Dict[int, List[int]]],
computed_block_nums: List[int],
n_seqs: int = 0,
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window).
seq_lens: Optional[List[int]] = None,
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None,
# This is used in the dual-chunk flash attention backend.
prompt_lens: Optional[List[int]] = None,
# The query length.
query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
context_lens: Optional[List[int]] = None,
# The current sliding window block.
curr_sliding_window_blocks: Optional[List[int]] = None,
# LoRA inputs.
lora_index_mapping: Optional[List[List[int]]] = None,
lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None,
# Multi-modal inputs.
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
reinit: bool = False,
reinit_use_defaults: bool = False,
encoder_seq_len: int = 0,
):
if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore
for i, seq_id in enumerate(seq_ids):
self.seq_ids[i] = seq_id # type: ignore
else:
self.seq_ids = seq_ids
self.request_id = request_id
self.is_prompt = is_prompt
self.block_tables = block_tables
self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs
self.encoder_seq_len = encoder_seq_len
if reinit:
if len(self.seq_ids) == 1 and reinit_use_defaults:
self.simple_reinit()
else:
if input_tokens:
self.input_tokens = input_tokens
else:
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions:
self.input_positions = input_positions
else:
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
self.mrope_input_positions = None
if seq_lens:
self.seq_lens = seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.seq_lens[seq_id] = 0
if orig_seq_lens:
self.orig_seq_lens = orig_seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0
if prompt_lens:
self.prompt_lens = prompt_lens
else:
for seq_id in range(len(self.seq_ids)):
self.prompt_lens[seq_id] = 0
if query_lens:
self.query_lens = query_lens
else:
for seq_id in range(len(self.seq_ids)):
self.query_lens[seq_id] = 0
if context_lens:
self.context_lens = context_lens
else:
for seq_id in range(len(self.seq_ids)):
self.context_lens[seq_id] = 0
if curr_sliding_window_blocks:
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks
else:
for seq_id in range(len(self.seq_ids)):
self.curr_sliding_window_blocks[seq_id] = 0
if lora_index_mapping:
self.lora_index_mapping = lora_index_mapping
else:
self.lora_index_mapping.clear()
if lora_prompt_mapping:
self.lora_prompt_mapping = lora_prompt_mapping
else:
self.lora_prompt_mapping.clear()
if lora_requests:
self.lora_requests = lora_requests
else:
self.lora_requests.clear()
else:
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.prompt_lens = prompt_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
self.n_seqs = len(self.seq_ids)
if not reinit:
self.__post_init__()
def __post_init__(self):
self.n_seqs = len(self.seq_ids)
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.prompt_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="",
seq_ids=[0] * num_seqs,
is_prompt=True,
block_tables=None,
computed_block_nums=[])
def init_cached_inter_data(self, *args, **kwargs):
assert len(args) == 0
assert "seq_ids" in kwargs
seq_ids = kwargs["seq_ids"]
num_seqs = len(seq_ids)
# The inter-data cache is per model_runner
inter_data_cache = self.runner.inter_data_cache
if num_seqs not in inter_data_cache:
inter_data_cache[num_seqs] = PyObjectCache(
self.gen_inter_data_builder(num_seqs))
obj = inter_data_cache[num_seqs].get_object()
obj.__init__(*args, **kwargs)
return obj
def reset_cached_inter_data(self):
for cache in self.runner.inter_data_cache.values():
cache.reset()
def __init__(self,
runner: "GPUModelRunnerBase",
finished_requests_ids: Optional[List[str]] = None):
super().__init__()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self.per_seq_compute_fns = [
self._compute_lens,
self._compute_for_prefix_cache_hit,
self._compute_for_sliding_window,
self._compute_lora_input,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [
self._compute_multi_modal_input,
]
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.scheduler_config = self.runner.scheduler_config
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
# Attention metadata inputs.
if self.attn_backend is not None:
# spec decode (e.g. Medusa) does not have atten backend
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
# Engine/Model configurations.
self.chunked_prefill_enabled = (
self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled)
if self.sliding_window is not None:
self.sliding_window_blocks = (
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.finished_requests_ids = finished_requests_ids
# if the current batch is decode-only.
# will be set to False if there is any non-decode request.
self.decode_only = True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
self.attn_metadata_builder.prepare()
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.token_chunk_size
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.model_config.is_encoder_decoder:
context_len = seq_len - 1
else:
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
if seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None:
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
inter_data.mrope_input_positions[
seq_idx] = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Check if hit prefix cache (i.e., some blocks are already computed).
If hit, update input tokens and positions to only compute the
remaining blocks.
"""
computed_block_nums = inter_data.computed_block_nums
# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if not prefix_cache_hit:
return
assert computed_block_nums is not None
# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
seq_group_metadata.seq_data[inter_data.seq_ids[
seq_idx]].update_num_cached_tokens(prefix_cache_len)
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:]
context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block = 0
sliding_seq_len = inter_data.seq_lens[seq_idx]
if not inter_data.is_prompt and self.sliding_window is not None:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
curr_sliding_window_block = self.sliding_window_blocks
# number of elements in last block
suff_len = inter_data.seq_lens[seq_idx] % self.block_size
sliding_seq_len = min(inter_data.seq_lens[seq_idx],
self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_block += 1
inter_data.curr_sliding_window_blocks[
seq_idx] = curr_sliding_window_block
inter_data.seq_lens[seq_idx] = sliding_seq_len
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if not self.enable_lora:
return
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
sampling_params = seq_group_metadata.sampling_params
if sampling_params and sampling_params.prompt_logprobs is not None:
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
inter_data.lora_prompt_mapping.append([lora_id])
else:
inter_data.lora_prompt_mapping.append([])
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
positions = inter_data.input_positions[0]
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
# M-RoPE requires mrope_positions even for plain text; return early
# when mm_kwargs is empty only if inter_data.is_prompt is False.
if not mm_kwargs and not inter_data.is_prompt:
return
inter_data.multi_modal_kwargs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
None)
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_config = self.runner.model_config.hf_config
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
for seq_idx in range(inter_data.n_seqs):
seq_data = seq_group_metadata.seq_data[
inter_data.seq_ids[seq_idx]]
token_ids = seq_data.get_token_ids()
if supports_mrope(self.runner.model):
mrope_input_positions, mrope_position_delta = \
self.runner.model.get_mrope_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
mrope_input_positions = mrope_input_positions.tolist()
else:
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
seq_data.mrope_position_delta = mrope_position_delta
inter_data.mrope_input_positions[
seq_idx] = mrope_input_positions
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
if is_prompt:
assert n_seqs == 1
self.decode_only = False
encoder_seq_len = 0
if self.runner.model_config.is_encoder_decoder:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
inter_data = self.init_cached_inter_data(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums,
reinit=True,
reinit_use_defaults=True,
encoder_seq_len=encoder_seq_len)
self.inter_data_list.append(inter_data)
for seq_idx in range(n_seqs):
for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)
def _use_captured_graph(self,
batch_size: int,
decode_only: bool,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> bool:
return (decode_only and not self.runner.model_config.enforce_eager
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
and batch_size <= self.runner.max_batchsize_to_capture)
def _get_cuda_graph_pad_size(self,
num_seqs: int,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> int:
"""
Determine the number of padding sequences required for running in
CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
In the multi-step + chunked-prefill case, only the first step
has Prefills (if any). The rest of the steps are guaranteed to be all
decodes. In this case, we set up the padding as if all the sequences
are decodes so we may run all steps except the first step in CUDA graph
mode.
Args:
num_seqs (int): Number of sequences scheduled to run.
max_decode_seq_len (int): Greatest of all the decode sequence
lengths. Used only in checking the viablility of using
CUDA graphs.
max_encoder_seq_len (int, optional): Greatest of all the encode
sequence lengths. Defaults to 0. Used only in checking the
viability of using CUDA graphs.
Returns:
int: Returns the determined number of padding sequences. If
CUDA graphs is not viable, returns -1.
"""
decode_only = self.decode_only
if not decode_only:
# Early exit so we can treat num_seqs as the batch_size below.
return -1
# batch_size out of this function refers to the number of input
# tokens being scheduled. This conflation of num_seqs as batch_size
# is valid as this is a decode-only case.
batch_size = num_seqs
if not self._use_captured_graph(batch_size, decode_only,
max_decode_seq_len,
max_encoder_seq_len):
return -1
graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
batch_size)
assert graph_batch_size >= batch_size
return graph_batch_size - batch_size
def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = list[int]()
inputs_embeds_list = list[torch.Tensor]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
if inter_data.inputs_embeds is not None:
inputs_embeds_list.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_list) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
mrope_input_positions: Optional[List[List[int]]] = None
if any(inter_data.mrope_input_positions is not None
for inter_data in self.inter_data_list):
mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
for inter_data in self.inter_data_list:
msections = inter_data.mrope_input_positions
if msections is None:
for _seq_input_positions in inter_data.input_positions:
mrope_input_positions[idx].extend(
_seq_input_positions)
else:
for _seq_mrope_input_positions in msections:
mrope_input_positions[idx].extend(
_seq_mrope_input_positions[idx])
input_positions = None
else:
input_positions = []
for inter_data in self.inter_data_list:
for cur_input_positions in inter_data.input_positions:
input_positions.extend(cur_input_positions)
seq_lens = []
query_lens = []
max_decode_seq_len = 0
max_encoder_seq_len = 0
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
query_lens.extend(inter_data.query_lens)
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
if self.runner.model_config.is_encoder_decoder:
max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len)
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
request_ids_to_seq_ids = {
data.request_id: data.seq_ids
for data in self.inter_data_list
}
cuda_graph_pad_size = self._get_cuda_graph_pad_size(
num_seqs=len(seq_lens),
max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len)
batch_size = len(input_tokens)
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
batch_size += cuda_graph_pad_size
# Tokens and positions.
if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
self.runner.pin_memory)
if mrope_input_positions is not None:
for idx in range(3):
mrope_input_positions[idx].extend(
itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(mrope_input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
else:
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
# Sequence and query lengths.
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(r for data in self.inter_data_list
for r in data.lora_requests)
lora_index_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
if cuda_graph_pad_size:
lora_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
**dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping,
is_prefill=not self.decode_only))
# Multi-modal data.
multi_modal_kwargs_list = [
data.multi_modal_kwargs for data in self.inter_data_list
if data.multi_modal_kwargs is not None
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder]
builder: ModelInputForGPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size
#
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.has_inner_state = model_config.has_inner_state
self.in_profile_run = False
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max seq len to capture / block size).
self.graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
self.cross_layer_shared_graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
# However we must bypass attention selection altogether for some models
# used for speculative decoding to avoid a divide-by-zero in
# model_config.get_head_size()
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
weakref.proxy(self))
else:
self.attn_state = CommonAttentionState(weakref.proxy(self))
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
# Lazy initialization
self.model: nn.Module # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
# Used to cache python objects
self.inter_data_cache: Dict[int, PyObjectCache] = {}
# Using the PythonizationCache in Pipeline-Parallel clobbers the
# SequenceGroupToSample object. In Pipeline-Parallel, we have
# more than 1 Scheduler, resulting in a potential back-to-back
# prepare_model_inputs() call. This clobbers the cached
# SequenceGroupToSample objects, as we reset the cache during
# every prepare_model_inputs() call.
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler(self.device) as m:
time_before_load = time.perf_counter()
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
self.lora_manager = LRUCacheWorkerLoRAManager(
self.vllm_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
)
self.model = self.lora_manager.create_lora_manager(self.model)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load)
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend(
self.vllm_config)
compilation_counter.dynamo_as_is_count += 1
self.model = torch.compile(self.model,
fullgraph=True,
backend=backend)
def get_model(self) -> nn.Module:
return self.model
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model,
path,
pattern=pattern,
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
from vllm.model_executor.model_loader import TensorizerLoader
TensorizerLoader.save_model(
self.model,
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
self.builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
try:
self.builder.add_seq_group(seq_group_metadata)
except Exception as e:
# Raise an exception that tracks the ID of the bad request
raise InputProcessingError(seq_group_metadata.request_id,
str(e)) from e
self.builder.reset_cached_inter_data()
return self.builder.build() # type: ignore
@contextmanager
def set_in_profile_run(self):
self.in_profile_run = True
try:
yield
finally:
self.in_profile_run = False
@torch.inference_mode()
def profile_run(self) -> None:
max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
self._dummy_run(max_num_batched_tokens, max_num_seqs)
def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
assert num_loras > 0
assert self.lora_manager is not None
dummy_lora_requests: list[LoRARequest] = []
with self.lora_manager.dummy_lora_cache():
for idx in range(num_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
return dummy_lora_requests
def _remove_dummy_loras(self):
# Remove dummy loras.
assert self.lora_manager is not None
self.remove_all_loras()
def _dummy_run(self,
max_num_batched_tokens: int,
max_num_seqs: int = 1) -> None:
with self.set_in_profile_run():
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = \
SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
dummy_lora_requests = self._add_dummy_loras(
self.lora_config.max_loras)
assert len(dummy_lora_requests) == self.lora_config.max_loras
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the
# total number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for multi-modal encoding,
# which needs to be accounted for when calculating the GPU blocks
# for vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.
multi_modal_placeholders,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = \
self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
# Disable KV Scale Calculation for dummy data during profile run
if model_input.attn_metadata is not None:
model_input.attn_metadata.enable_kv_scales_calculation = False
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
if self.lora_config:
self._remove_dummy_loras()
return
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
"""Cuda graph capture a model and return cudagraph memory
consumption in bytes.
Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
requires fixed sized tensors, supporting large/variable batch
size requires high GPU memory overhead. Thus, vLLM only captures
decoding requests. Mixed batch (chunked prefill + decoding) or
prefill requests are not captured.
Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
assert not self.model_config.enforce_eager
logger.info("Capturing cudagraphs for decoding. This may lead to "
"unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or "
"use '--enforce-eager' in the CLI. "
"If out-of-memory error occurs during cudagraph capture,"
" consider decreasing `gpu_memory_utilization` or "
"switching to eager mode. You can also reduce the "
"`max_num_seqs` as needed to decrease memory usage.")
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
input_positions = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
inputs_embeds = torch.zeros(
(max_batch_size, self.model_config.get_hidden_size()),
dtype=self.model_config.dtype,
device=self.device)
if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device)
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
previous_hidden_states = None
if "previous_hidden_states" in inspect.signature(
self.model.forward).parameters:
previous_hidden_states = torch.empty(
[max_batch_size,
self.model_config.get_hidden_size()],
dtype=self.model_config.dtype,
device=self.device)
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
batch_size=max_batch_size,
dtype=self.model_config.dtype,
device=self.device)
dummy_lora_id: Optional[int] = None
dummy_lora_request: LoRARequest = []
if self.lora_config:
# The goal is to capture the LoRA kernels in cuda graphs.
# for this purpose, as single dummy lora is sufficient.
dummy_lora_requests = self._add_dummy_loras(num_loras=1)
assert len(dummy_lora_requests) == 1
dummy_lora_request = dummy_lora_requests[0]
dummy_lora_id = dummy_lora_request.lora_int_id
with self.attn_state.graph_capture(max_batch_size), graph_capture(
self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
# We need to not only iterate over batch sizes, but also whether
# to use inputs_embeds or not, hence we use the cartesian
# product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = ((
True, False) if self.model_config.enable_prompt_embeds else
(False, ))
compilation_cases = itertools.product(
cudagraph_capture_sizes,
cudagraph_inputs_embeds,
)
# Only rank 0 should print progress bar during capture
if get_tensor_model_parallel_rank() == 0:
compilation_cases = tqdm(
list(compilation_cases),
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder))
# Disable KV Scale Calculation for graph capture
attn_metadata.enable_kv_scales_calculation = False
if self.lora_config:
lora_mapping = LoRAMapping(
**dict(index_mapping=[dummy_lora_id] * batch_size,
prompt_mapping=[dummy_lora_id] * batch_size,
is_prefill=False))
self.set_active_loras(set([dummy_lora_request]),
lora_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size),
self.model_config.is_encoder_decoder)
capture_inputs = {
"input_ids":
input_tokens[:batch_size],
"inputs_embeds":
inputs_embeds[:batch_size]
if use_inputs_embeds else None,
"positions":
input_positions[..., :batch_size],
"intermediate_inputs":
intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None,
"kv_caches":
kv_caches[virtual_engine],
"attn_metadata":
attn_metadata,
"memory_pool":
self.graph_memory_pool,
"stream":
graph_capture_context.stream
}
if previous_hidden_states is not None:
capture_inputs[
"previous_hidden_states"] = previous_hidden_states[:
batch_size]
if self.has_inner_state:
# Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs.update({
"seqlen_agnostic_capture_inputs":
self.model.get_seqlen_agnostic_capture_inputs(
batch_size)
})
if self.model_config.is_encoder_decoder:
# add the additional inputs to capture for
# encoder-decoder models.
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][(
batch_size, use_inputs_embeds)] = graph_runner
if self.lora_config:
self._remove_dummy_loras()
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes < 10 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / GiB_bytes)
return cuda_graph_size
def _update_inputs_to_capture_for_enc_dec_model(self,
capture_inputs: Dict[str,
Any]):
"""
Updates the set of input tensors needed for CUDA graph capture in an
encoder-decoder model.
This method modifies the provided `capture_inputs` dictionary by
adding tensors specific to encoder-decoder specific models that
need to be captured for CUDA Graph replay.
"""
# During the decode phase encoder_input_ids and encoder_positions are
# unset. Do the same thing for graph capture.
capture_inputs["encoder_input_ids"] = torch.tensor([],
dtype=torch.long,
device=self.device)
capture_inputs["encoder_positions"] = torch.tensor([],
dtype=torch.long,
device=self.device)
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
GPU model runner with sampling step.
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForGPUWithSamplingMetadata:
model_input = \
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
return model_input
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators, self.sampling_metadata_cache)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
use_inputs_embeds = model_input.inputs_embeds is not None
model_executable = self.graph_runners[virtual_engine][(
graph_batch_size, use_inputs_embeds)]
if previous_hidden_states is not None:
previous_hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
model_executable = self.model
# Receive KV cache in distributed KV cache transfer setting
# In disagg prefill setting, it will also recv hidden states and bypass
# model forwarding
# In KV cache database setting, it will change the model input so that
# we can skip prefilling on tokens that successfully received KV caches
# NOTE: The receive operation is blocking
bypass_model_exec = False
if self.need_recv_kv(model_input, kv_caches):
hidden_or_intermediate_states, bypass_model_exec, model_input = \
get_kv_transfer_group().recv_kv_caches_and_hidden_states(
# model is used to know which layer the current worker
# is working on, so that we can receive KV for only those
# layers.
model_executable,
model_input,
kv_caches=kv_caches
)
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
model_kwargs = {}
if previous_hidden_states is not None:
model_kwargs["previous_hidden_states"] = previous_hidden_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
device=self.device,
),
**seqlen_agnostic_kwargs,
**model_kwargs,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if self.is_driver_worker:
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the
# latency from the start time of the driver worker to the end
# time of the driver worker. The model forward time will then
# end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled_token_ids = []
valid_outputs = []
for sequence_group_output in output.outputs:
if len(sequence_group_output.samples) == 0:
continue
assert len(sequence_group_output.samples) == 1
valid_outputs.append(sequence_group_output)
sampled_token_ids.append(
sequence_group_output.samples[0].output_token)
sampled_token_ids = torch.tensor(sampled_token_ids).to(
self.device)
sampled_token_ids = broadcast_tensor_dict(
{"sampled_token_ids":
sampled_token_ids})["sampled_token_ids"]
else:
sampled_token_ids = broadcast_tensor_dict(
)["sampled_token_ids"]
if len(sampled_token_ids) > 0:
sampled_token_embeds = \
self.model.get_input_embeddings(sampled_token_ids)
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs
for i, sequence_group_output in enumerate(valid_outputs):
sequence_group_output.samples[0].output_embed = \
sampled_token_embeds[i]
if not self.is_driver_worker:
return []
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt:
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif decode_meta.use_cuda_graph:
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states
output.hidden_states = hidden_states
return [output]
def need_recv_kv(self, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor]) -> bool:
"""Check if we need to receive kv-cache from the other worker.
We need to receive KV when
1. current vLLM instance is KV cache consumer/decode vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
if self.vllm_config.kv_transfer_config is None:
return False
if model_input.attn_metadata is None:
raise ValueError("model_input.attn_metadata cannot be None")
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
return self.vllm_config.kv_transfer_config.is_kv_consumer and (
not is_profile_run) and is_prefill_run
def need_send_kv(self, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor]) -> bool:
"""Check if we need to send kv-cache to the other worker.
We need to send KV when
1. current vLLM instance is KV cache producer/prefill vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
if self.vllm_config.kv_transfer_config is None:
return False
if model_input.attn_metadata is None:
raise ValueError("model_input.attn_metadata cannot be None")
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
return self.vllm_config.kv_transfer_config.is_kv_producer and (
not is_profile_run) and is_prefill_run
# NOTE: this is nn.Module so the profiler can properly capture/group
# kernels calls made within the graph
class CUDAGraphRunner(nn.Module):
def __init__(self, model: nn.Module, backend_name: str,
attn_state: AttentionState, is_encoder_decoder_model: bool):
super().__init__()
self.model = model
self.backend_name = backend_name
self.attn_state = attn_state
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
self._is_encoder_decoder_model = is_encoder_decoder_model
@property
def graph(self):
assert self._graph is not None
return self._graph
def capture(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs,
):
assert self._graph is None
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# Note one iteration is not enough for torch.compile
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
# Wait for the warm up operations to finish before proceeding with
# Graph Capture.
torch.cuda.synchronize()
# Capture the graph.
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model(
input_ids=input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
hidden_or_intermediate_states = weak_ref_tensor(
output_hidden_or_intermediate_states)
elif isinstance(output_hidden_or_intermediate_states,
IntermediateTensors):
hidden_or_intermediate_states = IntermediateTensors(
tensors={
key: weak_ref_tensor(value)
for key, value in
output_hidden_or_intermediate_states.tensors.items()
})
del output_hidden_or_intermediate_states
# make sure `output_hidden_or_intermediate_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize()
# Save the input and output buffers.
self.input_buffers = {
"input_ids":
input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
"positions":
positions,
"kv_caches":
kv_caches,
**self.attn_state.get_graph_input_buffers(
attn_metadata, self._is_encoder_decoder_model),
**kwargs,
}
if intermediate_inputs is not None:
self.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
self.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
self.output_buffers = hidden_or_intermediate_states
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs,
) -> torch.Tensor:
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
if positions is not None:
# in some case like MLA, it will reuse positions in metadata
# but truncate them to the original size
# so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True)
if inputs_embeds is not None:
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
inputs_embeds, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(
attn_metadata.slot_mapping, non_blocking=True)
self.attn_state.prepare_graph_input_buffers(
self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)
if "previous_hidden_states" in self.input_buffers:
self.input_buffers["previous_hidden_states"].copy_(
kwargs["previous_hidden_states"], non_blocking=True)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
if key != "model_execute_time" and key != "model_forward_time":
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
if self._is_encoder_decoder_model:
self.input_buffers["encoder_input_ids"].copy_(
kwargs['encoder_input_ids'], non_blocking=True)
self.input_buffers["encoder_positions"].copy_(
kwargs['encoder_positions'], non_blocking=True)
# Run the graph.
self.graph.replay()
# Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]
return self.output_buffers
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from contextlib import nullcontext
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
memory_profiling)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ("medusa",
"mlp_speculator",
"eagle",
"deepseek_mtp",
"glm4_moe_mtp",
"mimo_mtp",
"ernie_mtp",
"qwen3_next_mtp")) \
else {"return_hidden_states": True}
self.model_runner: GPUModelRunnerBase = ModelRunner(
vllm_config=self.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
# Buffers saved before sleep
self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
# only print profiler results on rank 0
if self.local_rank == 0:
print(self.profiler.key_averages().table(
sort_by="self_cuda_time_total"))
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# Save the buffers before level 2 sleep
if level == 2:
model = self.model_runner.model
self._sleep_saved_buffers = {
name: buffer.cpu().clone()
for name, buffer in model.named_buffers()
}
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=tags)
# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.model_runner.model
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.baseline_snapshot = MemorySnapshot()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
context = nullcontext()
with context:
self.model_runner.load_model()
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_runner.save_sharded_state(
path,
pattern=pattern,
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
def determine_available_kv_cache_memory(self,
total_gpu_memory: int) -> float:
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
# still need a profile run which compiles the model for
# max_num_batched_tokens
self.model_runner.profile_run()
GiB = lambda b: b / GiB_bytes
msg = (
f"Initial free memory "
f"{GiB(self.baseline_snapshot.free_memory):.2f} "
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
"KV Cache as specified by kv_cache_memory_bytes config and "
"skipped memory profiling. This does does not respect the "
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
"config when you want manual control of KV cache memory "
"size. If OOM'ed, check the difference of initial free "
"memory between the current run and the previous run "
"where kv_cache_memory_bytes is suggested and update it "
"correspondingly.")
logger.info(msg)
return self.cache_config.kv_cache_memory_bytes
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.baseline_snapshot,
weights_memory=self.model_runner.model_memory_usage) as result:
self.model_runner.profile_run()
self.non_torch_memory = result.non_torch_increase
self.peak_activation_memory = result.torch_peak_increase
self._assert_memory_footprint_increased_during_profiling()
self.requested_memory = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
self.available_kv_cache_memory = (self.requested_memory -
result.non_kv_cache_memory)
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
"the current vLLM instance can use "
"total_gpu_memory "
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
logger.info(msg)
return self.available_kv_cache_memory
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
available_kv_cache_memory = self.determine_available_kv_cache_memory(
total_gpu_memory)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Final cleanup
gc.collect()
return num_gpu_blocks, num_cpu_blocks
def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, total = torch.cuda.mem_get_info()
cuda_memory = total - free_gpu_memory
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
"Error in memory profiling. "
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
f"currently used memory {cuda_memory}. "
f"This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(
num_gpu_blocks, self.cache_config.block_size,
self.cache_config.is_attention_free,
self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
context = nullcontext()
with context:
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
CacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
shared_kv_cache_layers: dict[str, str] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
shared_kv_cache_layers[layer_name] = kv_tgt_layer
bind_kv_cache(self.compilation_config.static_forward_context,
self.gpu_cache, shared_kv_cache_layers)
def _warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
cuda_graph_memory_bytes = 0
if not self.model_config.enforce_eager:
cuda_graph_memory_bytes = self.model_runner.capture_model(
self.gpu_cache)
if (self.cache_config.kv_cache_memory_bytes is None
and hasattr(self, "peak_activation_memory")):
# Suggests optimal kv cache memory size if we rely on
# memory_profiling to guess the kv cache memory size which
# provides peak_activation_memory and a few other memory
# consumption. `memory_profiling` does not consider
# CUDAGraph memory size and may not utilize all gpu memory.
# Users may want fine-grained control to specify kv cache
# memory size.
GiB = lambda b: round(b / GiB_bytes, 2)
non_kv_cache_memory = (self.model_runner.model_memory_usage +
self.peak_activation_memory +
self.non_torch_memory +
cuda_graph_memory_bytes)
# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
kv_cache_memory_bytes_to_gpu_limit = (
self.baseline_snapshot.free_memory - non_kv_cache_memory -
redundancy_buffer_memory)
kv_cache_memory_bytes_to_requested_limit = (
int(self.requested_memory) - non_kv_cache_memory -
redundancy_buffer_memory)
msg = (
f"Free memory on device "
f"({GiB(self.baseline_snapshot.free_memory)}/"
f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
f"Desired GPU memory utilization is "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). "
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
f"requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
f"utilize gpu memory. Current kv cache memory in use is "
f"{int(self.available_kv_cache_memory)} bytes.")
logger.info(msg)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]],
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list = []
for metadata_or_delta in seq_group_metadata_list:
request_id = metadata_or_delta.request_id
if request_id not in self._seq_group_metadata_cache:
# The first prefill.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[request_id] = metadata_or_delta
else:
# The first prefill is already cached.
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
self._seq_group_metadata_cache[request_id].apply_delta(
metadata_or_delta)
else:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[
request_id] = metadata_or_delta
new_seq_group_metadata_list.append(
self._seq_group_metadata_cache[request_id])
# Clean up finished ids
for finished_id in finished_request_ids:
del self._seq_group_metadata_cache[finished_id]
return new_seq_group_metadata_list
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
if execute_model_req is not None:
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
execute_model_req.seq_group_metadata_list,
execute_model_req.finished_requests_ids)
execute_model_req.seq_group_metadata_list = (
new_seq_group_metadata_list)
output = super()._execute_model_spmd(execute_model_req,
intermediate_tensors)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return CacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank,
current_platform.dist_backend)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
max_model_len, pipeline_parallel_size) -> None:
if is_attention_free and num_gpu_blocks != 0:
raise ValueError("No memory should be allocated for the cache blocks "
f"for an attention-free model, but {num_gpu_blocks} "
"blocks are allocated.")
if not is_attention_free and num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
if not is_attention_free and max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment