Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import count, takewhile from itertools import count, takewhile
from os.path import commonprefix from os.path import commonprefix
...@@ -7,7 +8,7 @@ from typing import Sequence as GenericSequence ...@@ -7,7 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set from typing import Set
from vllm.block import BlockTable, PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
...@@ -46,6 +47,10 @@ class BlockAllocatorBase(ABC): ...@@ -46,6 +47,10 @@ class BlockAllocatorBase(ABC):
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self) -> int:
pass pass
@abstractmethod
def get_num_total_blocks(self) -> int:
pass
@abstractmethod @abstractmethod
def contains_block(self, block_hash: int) -> bool: def contains_block(self, block_hash: int) -> bool:
pass pass
...@@ -130,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase): ...@@ -130,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
return (self.num_blocks - self.current_num_blocks + return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks) self.evictor.num_blocks)
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool: def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor return block_hash in self.cached_blocks or block_hash in self.evictor
...@@ -189,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase): ...@@ -189,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self) -> int:
return len(self.free_blocks) return len(self.free_blocks)
def get_num_total_blocks(self) -> int:
return self.num_blocks
def contains_block(self, block_hash: int) -> bool: def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError( raise NotImplementedError(
"Invalid codepath for uncached block allocator.") "Invalid codepath for uncached block allocator.")
...@@ -220,9 +231,9 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -220,9 +231,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.block_sliding_window = None self.block_sliding_window = None
if sliding_window is not None: if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window, # Round up to nearest block size to regularize sliding window
block_size) # allocation sizes.
self.block_sliding_window = sliding_window // block_size self.block_sliding_window = math.ceil(sliding_window / block_size)
self.watermark = watermark self.watermark = watermark
assert watermark >= 0.0 assert watermark >= 0.0
...@@ -390,7 +401,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -390,7 +401,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table.append(block_table[len(block_table) % block_table.append(block_table[len(block_table) %
self.block_sliding_window]) self.block_sliding_window])
else: else:
# The sequence has a new logical block. # The sequence hash a new logical block.
# Allocate a new physical block. # Allocate a new physical block.
new_block = self._allocate_last_physical_block(seq) new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block) block_table.append(new_block)
...@@ -443,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -443,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def can_swap_in(self, def can_swap_in(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool: num_lookahead_slots: int = 0) -> AllocStatus:
assert (num_lookahead_slots == 0 assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation" ), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group) blocks = self._get_physical_blocks(seq_group)
...@@ -453,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -453,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# at least one free block right after the swap-in. # at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append_slot(). # NOTE: This should match the logic in can_append_slot().
num_required_blocks = len(blocks) + num_swapped_seqs num_required_blocks = len(blocks) + num_swapped_seqs
return num_free_blocks - num_required_blocks >= self.watermark_blocks if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
return AllocStatus.NEVER
elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def swap_in(self, def swap_in(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
......
...@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self.watermark = watermark self.watermark = watermark
assert watermark >= 0.0 assert watermark >= 0.0
assert not enable_caching, "Prefix caching not yet supported"
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks) self.watermark_blocks = int(watermark * num_gpu_blocks)
self.block_allocator = CpuGpuBlockAllocator.create( self.block_allocator = CpuGpuBlockAllocator.create(
# Currently, only naive blocks are supported (no prefix caching). allocator_type="prefix_caching" if enable_caching else "naive",
allocator_type="naive",
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
block_size=block_size, block_size=block_size,
...@@ -192,19 +190,30 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -192,19 +190,30 @@ class BlockSpaceManagerV2(BlockSpaceManager):
assert seq.seq_id in self.block_tables assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids) assert all(b is not None for b in block_ids)
return block_ids return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq, now): def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# TODO add prefix caching support. # Update the last accessed time of all the blocks accessed
# Tracked here https://github.com/vllm-project/vllm/issues/3667 # in this step.
pass # And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if self.enable_caching:
block_table = self.block_tables[seq.seq_id]
block_ids = []
for block_id in block_table.physical_block_ids:
block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# We ignore the sequence group as its not necessary. After the batch is # The only need for mark block as computed is for prefix caching,
# formed by the scheduler, we do not need to mark blocks from individual # while currently we could determine whether one block is computed
# sequence groups as computed -- all blocks in the batch can be marked # or not by check whether it has content hash.
# as computed. # So this function is useless for block_v2.
self.block_allocator.mark_blocks_as_computed() pass
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]: self, seqs: List[Sequence]) -> GenericSequence[int]:
...@@ -220,16 +229,17 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -220,16 +229,17 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq_block_ids = [ seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
] ]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids( return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork() self.block_tables[child_seq.seq_id] = src_block_table.fork()
def can_swap_in(self, seq_group: SequenceGroup, def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool: num_lookahead_slots: int) -> AllocStatus:
return False return AllocStatus.LATER
def swap_in(self, seq_group: SequenceGroup, def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]: num_lookahead_slots: int) -> Dict[int, int]:
......
import enum
from abc import ABC, abstractmethod, abstractproperty
from typing import OrderedDict, Tuple
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_id: int) -> bool:
pass
@abstractmethod
def evict(self) -> Tuple[int, int]:
"""Runs the eviction algorithm and returns the evicted block's
content hash along with physical block id along with physical block id
"""
pass
@abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata"""
pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
class BlockMetaData():
"""Data structure for storing key data describe cached block, so that
evitor could use to make its decision which one to choose for eviction
Here we use physical block id as the dict key, as there maybe several
blocks with the same content hash, but their physical id is unique.
"""
def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily
"""
def __init__(self):
self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
def __contains__(self, block_id: int) -> bool:
return block_id in self.free_table
def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block.last_accessed > block.last_accessed or (
evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
self.free_table.pop(evicted_block_id)
return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens,
last_accessed)
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int):
if block_id not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
self.free_table.pop(block_id)
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
...@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC): ...@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC):
@abstractmethod @abstractmethod
def can_swap_in(self, seq_group: SequenceGroup, def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool: num_lookahead_slots: int) -> AllocStatus:
pass pass
@abstractmethod @abstractmethod
......
import enum import enum
import os
import random
import time import time
from collections import deque from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
...@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts ...@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
logger = init_logger(__name__) logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
# ARTIFICIAL_PREEMPTION_PROB% probability.
ENABLE_ARTIFICIAL_PREEMPT = bool(
os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500
class PreemptionMode(enum.Enum): class PreemptionMode(enum.Enum):
"""Preemption modes. """Preemption modes.
...@@ -119,6 +128,8 @@ class SchedulerOutputs: ...@@ -119,6 +128,8 @@ class SchedulerOutputs:
ignored_seq_groups: List[SequenceGroup] ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int num_lookahead_slots: int
# The number of requests in the running queue
running_queue_size: int
def __post_init__(self): def __post_init__(self):
# Swap in and swap out should never happen at the same time. # Swap in and swap out should never happen at the same time.
...@@ -201,6 +212,8 @@ class SchedulerSwappedInOutputs: ...@@ -201,6 +212,8 @@ class SchedulerSwappedInOutputs:
blocks_to_copy: Dict[int, List[int]] blocks_to_copy: Dict[int, List[int]]
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int num_lookahead_slots: int
# Infeasible sequence groups.
infeasible_seq_groups: List[SequenceGroup]
@classmethod @classmethod
def create_empty(cls) -> "SchedulerSwappedInOutputs": def create_empty(cls) -> "SchedulerSwappedInOutputs":
...@@ -210,6 +223,7 @@ class SchedulerSwappedInOutputs: ...@@ -210,6 +223,7 @@ class SchedulerSwappedInOutputs:
blocks_to_swap_in={}, blocks_to_swap_in={},
blocks_to_copy={}, blocks_to_copy={},
num_lookahead_slots=0, num_lookahead_slots=0,
infeasible_seq_groups=[],
) )
...@@ -286,6 +300,13 @@ class Scheduler: ...@@ -286,6 +300,13 @@ class Scheduler:
# Latency of the last prompt step # Latency of the last prompt step
self.last_prompt_latency = 0.0 self.last_prompt_latency = 0.0
# The following field is test-only. It is used to inject artificial
# preemption.
self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
if self.enable_artificial_preemption
else 0)
@property @property
def lora_enabled(self) -> bool: def lora_enabled(self) -> bool:
return bool(self.lora_config) return bool(self.lora_config)
...@@ -320,7 +341,7 @@ class Scheduler: ...@@ -320,7 +341,7 @@ class Scheduler:
for seq_group in state_queue: for seq_group in state_queue:
if not request_ids: if not request_ids:
# Using 'break' here may add two extra iterations, # Using 'break' here may add two extra iterations,
# but is acceptable to reduce complexity . # but is acceptable to reduce complexity.
break break
if seq_group.request_id in request_ids: if seq_group.request_id in request_ids:
# Appending aborted group into pending list. # Appending aborted group into pending list.
...@@ -386,15 +407,13 @@ class Scheduler: ...@@ -386,15 +407,13 @@ class Scheduler:
# groups to preempt. # groups to preempt.
now = time.time() now = time.time()
running_queue = policy.sort_by_priority(now, running_queue) running_queue = policy.sort_by_priority(now, running_queue)
while running_queue: while running_queue:
seq_group = running_queue[0] seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens( num_running_tokens = self._get_num_new_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking, budget) seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
# We can have up to 1 running prefill at any given time in running if num_running_tokens == 0:
# queue, which means we can guarantee chunk size is at least 1. break
assert num_running_tokens != 0
running_queue.popleft() running_queue.popleft()
while not self._can_append_slots(seq_group): while not self._can_append_slots(seq_group):
...@@ -449,9 +468,6 @@ class Scheduler: ...@@ -449,9 +468,6 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id) curr_loras.add(seq_group.lora_int_id)
# Make sure all queues are updated.
assert len(running_queue) == 0
return running_queue, SchedulerRunningOutputs( return running_queue, SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups, decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups, prefill_seq_groups=prefill_seq_groups,
...@@ -500,14 +516,26 @@ class Scheduler: ...@@ -500,14 +516,26 @@ class Scheduler:
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time() now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue) swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = []
leftover_swapped: Deque[SequenceGroup] = deque() leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue: while swapped_queue:
seq_group = swapped_queue[0] seq_group = swapped_queue[0]
# If the sequence group cannot be swapped in, stop. # If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group): alloc_status = self.block_manager.can_swap_in(seq_group)
if alloc_status == AllocStatus.LATER:
break break
elif alloc_status == AllocStatus.NEVER:
logger.warning(
"Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence.",
seq_group.request_id)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
infeasible_seq_groups.append(seq_group)
swapped_queue.popleft()
continue
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
...@@ -545,7 +573,6 @@ class Scheduler: ...@@ -545,7 +573,6 @@ class Scheduler:
ScheduledSequenceGroup(seq_group, ScheduledSequenceGroup(seq_group,
token_chunk_size=num_new_tokens)) token_chunk_size=num_new_tokens))
else: else:
assert num_new_tokens == 1
decode_seq_groups.append( decode_seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1)) ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
...@@ -559,7 +586,9 @@ class Scheduler: ...@@ -559,7 +586,9 @@ class Scheduler:
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots( num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False)) is_prefill=False),
infeasible_seq_groups=infeasible_seq_groups,
)
def _schedule_prefills( def _schedule_prefills(
self, self,
...@@ -617,8 +646,9 @@ class Scheduler: ...@@ -617,8 +646,9 @@ class Scheduler:
if num_new_tokens > self.prompt_limit: if num_new_tokens > self.prompt_limit:
logger.warning( logger.warning(
f"Input prompt ({num_new_tokens} tokens) is too long" "Input prompt (%d tokens) is too long"
f" and exceeds limit of {self.prompt_limit}") " and exceeds limit of %d", num_new_tokens,
self.prompt_limit)
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -631,8 +661,9 @@ class Scheduler: ...@@ -631,8 +661,9 @@ class Scheduler:
break break
elif can_allocate == AllocStatus.NEVER: elif can_allocate == AllocStatus.NEVER:
logger.warning( logger.warning(
f"Input prompt ({num_new_tokens} tokens) is too long" "Input prompt (%d tokens) is too long"
f" and exceeds the capacity of block_manager") " and exceeds the capacity of block_manager",
num_new_tokens)
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -765,8 +796,10 @@ class Scheduler: ...@@ -765,8 +796,10 @@ class Scheduler:
blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups, ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots, num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
) )
def _schedule_chunked_prefill(self): def _schedule_chunked_prefill(self):
...@@ -853,6 +886,7 @@ class Scheduler: ...@@ -853,6 +886,7 @@ class Scheduler:
swapped_in.blocks_to_copy), swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups, ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots, num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
) )
def _schedule(self) -> SchedulerOutputs: def _schedule(self) -> SchedulerOutputs:
...@@ -866,6 +900,13 @@ class Scheduler: ...@@ -866,6 +900,13 @@ class Scheduler:
"""Determine whether or not we have enough space in the KV cache to """Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group. continue generation of the sequence group.
""" """
# It is True only for testing case to trigger artificial preemption.
if (self.enable_artificial_preemption
and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
and self.artificial_preempt_cnt > 0):
self.artificial_preempt_cnt -= 1
return False
# Appending slots only occurs in decoding. # Appending slots only occurs in decoding.
is_prefill = False is_prefill = False
...@@ -874,15 +915,6 @@ class Scheduler: ...@@ -874,15 +915,6 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
) )
def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
# Swapping in is considered decode.
is_prefill = False
return self.block_manager.can_swap_in(
seq_group=seq_group,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups. # Schedule sequence groups.
# This function call changes the internal states of the scheduler # This function call changes the internal states of the scheduler
...@@ -913,6 +945,20 @@ class Scheduler: ...@@ -913,6 +945,20 @@ class Scheduler:
self.block_manager.get_common_computed_block_ids( self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING))) seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True
if seq_group.is_prefill():
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
is_prompt = seq_group.is_prefill() is_prompt = seq_group.is_prefill()
...@@ -922,6 +968,7 @@ class Scheduler: ...@@ -922,6 +968,7 @@ class Scheduler:
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
do_sample=do_sample,
token_chunk_size=token_chunk_size, token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums, computed_block_nums=common_computed_block_nums,
...@@ -1099,11 +1146,14 @@ class Scheduler: ...@@ -1099,11 +1146,14 @@ class Scheduler:
if `enable_chunking` is True. If a sequence group has multiple if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen. phase, so chunking doesn't happen.
Returns 0 if the new token cannot be computed due to token budget.
""" """
num_new_tokens = 0 num_new_tokens = 0
seqs = seq_group.get_seqs(status=status) seqs = seq_group.get_seqs(status=status)
for seq in seqs: for seq in seqs:
num_new_tokens += seq.get_num_new_tokens() num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in. # Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a # If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case. # decode phase. Do not chunk in that case.
......
...@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from .parallel_state import (get_tensor_model_parallel_group, from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce) is_pynccl_enabled_for_all_reduce)
...@@ -33,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: ...@@ -33,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
if out is not None: if out is not None:
return out return out
if is_pynccl_enabled_for_all_reduce(): if is_pynccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
pynccl_utils.all_reduce(input_) pynccl_utils.all_reduce(input_)
else: else:
torch.distributed.all_reduce(input_, torch.distributed.all_reduce(input_,
...@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any], ...@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any],
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note(youkaichao): currently this only supports broadcasting
# tensors on cuda. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert value.is_cuda, (
f"Tensor {key}: {value} is not on cuda. Currently we only "
f"support broadcasting tensors on cuda.")
metadata_list.append((key, TensorMetadata(value.dtype,
value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
def broadcast_tensor_dict( def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.""" """Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
group = group or torch.distributed.group.WORLD group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group) ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})" assert src in ranks, f"Invalid src rank ({src})"
...@@ -161,27 +194,20 @@ def broadcast_tensor_dict( ...@@ -161,27 +194,20 @@ def broadcast_tensor_dict(
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}") dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
for key, value in tensor_dict.items(): metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
if isinstance(value, torch.Tensor): # `metadata_list` lives in CPU memory.
assert value.is_cuda, ( # `broadcast_object_list` involves serialization and deserialization,
f"Tensor {key}: {value} is not on cuda. Currently we only " # all happening on CPU. Therefore, we can use the CPU group.
f"support broadcasting tensors on cuda.")
metadata_list.append(
(key, TensorMetadata(value.dtype, value.size())))
else:
metadata_list.append((key, value))
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=metadata_group)
async_handles = [] async_handles = []
for key, value in metadata_list: for tensor in tensor_list:
if isinstance(value, TensorMetadata): async_handles.append(
tensor = tensor_dict[key] torch.distributed.broadcast(tensor,
async_handles.append( src=src,
torch.distributed.broadcast(tensor, group=group,
src=src, async_op=True))
group=group,
async_op=True))
for async_handle in async_handles: for async_handle in async_handles:
async_handle.wait() async_handle.wait()
...@@ -189,7 +215,7 @@ def broadcast_tensor_dict( ...@@ -189,7 +215,7 @@ def broadcast_tensor_dict(
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
src=src, src=src,
group=group) group=metadata_group)
assert recv_metadata_list[0] is not None assert recv_metadata_list[0] is not None
tensor_dict = {} tensor_dict = {}
async_handles = [] async_handles = []
......
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, List, Optional from typing import Any, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
try: try:
...@@ -37,7 +37,7 @@ def init_custom_ar() -> None: ...@@ -37,7 +37,7 @@ def init_custom_ar() -> None:
return return
if world_size not in _SUPPORTED_WORLD_SIZES: if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warn( logger.warning(
"Custom allreduce is disabled due to an unsupported world size: " "Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify" "%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size, " disable_custom_all_reduce=True explicitly.", world_size,
...@@ -47,22 +47,22 @@ def init_custom_ar() -> None: ...@@ -47,22 +47,22 @@ def init_custom_ar() -> None:
# note: num dev can be larger than world_size if we're only using # note: num dev can be larger than world_size if we're only using
# first few GPUs # first few GPUs
if num_dev < world_size: if num_dev < world_size:
logger.warn( logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the " "Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.") " is set.")
return return
# test nvlink first, this will filter out most of the cases # test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported # where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ: cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
device_ids = list( if cuda_visible_devices:
map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
else: else:
device_ids = list(range(num_dev)) device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink # this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids) full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink: if world_size > 2 and not full_nvlink:
logger.warn( logger.warning(
"Custom allreduce is disabled because it's not supported on more" "Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify" " than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.") " disable_custom_all_reduce=True explicitly.")
...@@ -71,7 +71,7 @@ def init_custom_ar() -> None: ...@@ -71,7 +71,7 @@ def init_custom_ar() -> None:
# this is expensive to compute at the first time # this is expensive to compute at the first time
# then we cache the result # then we cache the result
if not _can_p2p(rank, world_size): if not _can_p2p(rank, world_size):
logger.warn( logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P" "Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify" " capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.") " disable_custom_all_reduce=True explicitly.")
......
...@@ -43,15 +43,16 @@ try: ...@@ -43,15 +43,16 @@ try:
nccl = ctypes.CDLL(so_file) nccl = ctypes.CDLL(so_file)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to load NCCL library from {so_file} ." "Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs." "It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted " "Otherwise, the nccl library might not exist, be corrupted "
f"or it does not support the current platform {platform.platform()}." "or it does not support the current platform %s."
f"One solution is to download libnccl2 version 2.18 from " "One solution is to download libnccl2 version 2.18 from "
f"https://developer.download.nvidia.com/compute/cuda/repos/ " "https://developer.download.nvidia.com/compute/cuda/repos/ "
f"and extract the libnccl.so.2 file. If you already have the " "and extract the libnccl.so.2 file. If you already have the "
f"library, please set the environment variable VLLM_NCCL_SO_PATH" "library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.") " to point to the correct nccl library path.", so_file,
platform.platform())
raise e raise e
# === export types and functions from nccl to Python === # === export types and functions from nccl to Python ===
...@@ -199,6 +200,10 @@ _c_ncclAllReduce.argtypes = [ ...@@ -199,6 +200,10 @@ _c_ncclAllReduce.argtypes = [
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
] ]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration: # equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm); # ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy _c_ncclCommDestroy = nccl.ncclCommDestroy
...@@ -227,6 +232,7 @@ class NCCLCommunicator: ...@@ -227,6 +232,7 @@ class NCCLCommunicator:
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.") "NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group self.group = group
# note: this rank is the rank in the group
self.rank = dist.get_rank(group) self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group) self.world_size = dist.get_world_size(group)
if self.rank == 0: if self.rank == 0:
...@@ -234,7 +240,9 @@ class NCCLCommunicator: ...@@ -234,7 +240,9 @@ class NCCLCommunicator:
else: else:
self.unique_id = NcclUniqueId() self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)) tensor = torch.ByteTensor(list(self.unique_id.internal))
dist.broadcast(tensor, src=0, group=group) ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist() byte_list = tensor.tolist()
for i, byte in enumerate(byte_list): for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte self.unique_id.internal[i] = byte
...@@ -250,15 +258,13 @@ class NCCLCommunicator: ...@@ -250,15 +258,13 @@ class NCCLCommunicator:
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
self.device = device self.device = device
# nccl communicator and stream will use this device # nccl communicator and stream will use this device
current_device = torch.cuda.current_device() # `torch.cuda.device` is a context manager that changes the
try: # current cuda device to the specified one
torch.cuda.set_device(device) with torch.cuda.device(device):
NCCL_CHECK( NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)) self.unique_id, self.rank))
self.stream = torch.cuda.Stream() self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
def all_reduce(self, def all_reduce(self,
tensor: torch.Tensor, tensor: torch.Tensor,
...@@ -279,11 +285,3 @@ class NCCLCommunicator: ...@@ -279,11 +285,3 @@ class NCCLCommunicator:
ncclDataTypeEnum.from_torch(tensor.dtype), ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm, ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))) ctypes.c_void_p(stream.cuda_stream)))
def __del__(self):
# `dist` module might have been already destroyed
if hasattr(dist, 'destroy_process_group'):
dist.destroy_process_group()
# function might have been already destroyed
if _c_ncclCommDestroy is not None:
_c_ncclCommDestroy(self.comm)
...@@ -14,7 +14,7 @@ try: ...@@ -14,7 +14,7 @@ try:
except Exception as e: except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module # in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs # e.g. when running on machines with AMD GPUs
logger.info(f"Failed to import NCCL library: {e}") logger.info("Failed to import NCCL library: %s", e)
logger.info("It is expected if you are not running on NVIDIA GPUs.") logger.info("It is expected if you are not running on NVIDIA GPUs.")
pass pass
...@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream): ...@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
def init_process_group(group: Optional[ProcessGroup] = None) -> None: def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized() assert not is_initialized()
global comm global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}") logger.info("vLLM is using nccl==%s", ncclGetVersion())
comm = NCCLCommunicator(group=group) comm = NCCLCommunicator(group=group)
......
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups.""" """Tensor and pipeline parallel groups."""
import contextlib import contextlib
import os
from typing import Optional from typing import Optional
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TP_DEVICE_GROUP = None
_TP_CPU_GROUP = None
# Pipeline model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
...@@ -57,8 +58,10 @@ def init_distributed_environment( ...@@ -57,8 +58,10 @@ def init_distributed_environment(
local_rank: int = -1, local_rank: int = -1,
backend: str = "nccl", backend: str = "nccl",
): ):
logger.debug(f"{world_size=} {rank=} {local_rank=} " logger.debug(
f"{distributed_init_method=} {backend=}") "world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
assert distributed_init_method is not None, ( assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing " "distributed_init_method must be provided when initializing "
...@@ -78,7 +81,7 @@ def init_distributed_environment( ...@@ -78,7 +81,7 @@ def init_distributed_environment(
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://": if local_rank == -1 and distributed_init_method == "env://":
local_rank = int(os.environ['LOCAL_RANK']) local_rank = envs.LOCAL_RANK
global _LOCAL_RANK global _LOCAL_RANK
_LOCAL_RANK = local_rank _LOCAL_RANK = local_rank
...@@ -130,15 +133,17 @@ def initialize_model_parallel( ...@@ -130,15 +133,17 @@ def initialize_model_parallel(
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP global _TP_DEVICE_GROUP, _TP_CPU_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( assert _TP_DEVICE_GROUP is None, (
"tensor model parallel group is already initialized") "tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size) (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks, backend=backend) group = torch.distributed.new_group(ranks, backend=backend)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks: if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group _TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_MODEL_PARALLEL_GROUP
...@@ -183,7 +188,7 @@ def ensure_model_parallel_initialized( ...@@ -183,7 +188,7 @@ def ensure_model_parallel_initialized(
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized.""" """Check if tensor and pipeline parallel groups are initialized."""
return (_TENSOR_MODEL_PARALLEL_GROUP is not None return (_TP_DEVICE_GROUP is not None
and _PIPELINE_MODEL_PARALLEL_GROUP is not None) and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
...@@ -195,9 +200,16 @@ def get_cpu_world_group(): ...@@ -195,9 +200,16 @@ def get_cpu_world_group():
def get_tensor_model_parallel_group(): def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to.""" """Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( assert _TP_DEVICE_GROUP is not None, (
"tensor model parallel group is not initialized") "tensor model parallel group is not initialized")
return _TENSOR_MODEL_PARALLEL_GROUP return _TP_DEVICE_GROUP
def get_tensor_model_parallel_cpu_group():
"""Get the tensor model parallel cpu group the caller rank belongs to."""
assert _TP_CPU_GROUP is not None, (
"tensor model parallel cpu group is not initialized")
return _TP_CPU_GROUP
def get_pipeline_model_parallel_group(): def get_pipeline_model_parallel_group():
...@@ -275,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank(): ...@@ -275,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none and destroy them.""" """Set the groups to none and destroy them."""
global _TENSOR_MODEL_PARALLEL_GROUP global _TP_DEVICE_GROUP
if _TENSOR_MODEL_PARALLEL_GROUP: if _TP_DEVICE_GROUP:
torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
_TENSOR_MODEL_PARALLEL_GROUP = None _TP_DEVICE_GROUP = None
global _TP_CPU_GROUP
if _TP_CPU_GROUP:
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
_TP_CPU_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_MODEL_PARALLEL_GROUP
if _PIPELINE_MODEL_PARALLEL_GROUP: if _PIPELINE_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
......
...@@ -9,6 +9,7 @@ from typing import Dict, Optional, Sequence ...@@ -9,6 +9,7 @@ from typing import Dict, Optional, Sequence
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank from .parallel_state import get_cpu_world_group, get_local_rank
...@@ -102,17 +103,19 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -102,17 +103,19 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
is_distributed = dist.is_initialized() is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count() num_dev = torch.cuda.device_count()
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None: if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
path = os.path.expanduser( path = os.path.expanduser(
f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
if (not is_distributed or get_local_rank() == 0) \ if (not is_distributed or get_local_rank() == 0) \
and (not os.path.exists(path)): and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
logger.info(f"generating GPU P2P access cache for in {path}") logger.info("generating GPU P2P access cache for in %s", path)
cache = {} cache = {}
for _i in range(num_dev): for _i in range(num_dev):
for _j in range(num_dev): for _j in range(num_dev):
...@@ -126,7 +129,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -126,7 +129,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
if is_distributed: if is_distributed:
cpu_world_group = get_cpu_world_group() cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group) dist.barrier(cpu_world_group)
logger.info(f"reading GPU P2P access cache from {path}") logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f: with open(path, "r") as f:
cache = json.load(f) cache = json.load(f)
_gpu_p2p_access_cache = cache _gpu_p2p_access_cache = cache
......
import argparse import argparse
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import List, Optional, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
...@@ -11,10 +11,17 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS ...@@ -11,10 +11,17 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple from vllm.utils import str_to_int_tuple
def nullable_str(val: str):
if not val or val == "None":
return None
return val
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str model: str
served_model_name: Optional[Union[List[str]]] = None
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto' tokenizer_mode: str = 'auto'
...@@ -44,7 +51,8 @@ class EngineArgs: ...@@ -44,7 +51,8 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
max_context_len_to_capture: int = 8192 max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0 tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray" tokenizer_pool_type: str = "ray"
...@@ -52,6 +60,7 @@ class EngineArgs: ...@@ -52,6 +60,7 @@ class EngineArgs:
enable_lora: bool = False enable_lora: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256 lora_extra_vocab_size: int = 256
lora_dtype = 'auto' lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
...@@ -74,6 +83,8 @@ class EngineArgs: ...@@ -74,6 +83,8 @@ class EngineArgs:
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -92,7 +103,7 @@ class EngineArgs: ...@@ -92,7 +103,7 @@ class EngineArgs:
help='Name or path of the huggingface model to use.') help='Name or path of the huggingface model to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=str, type=nullable_str,
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.') help='Name or path of the huggingface tokenizer to use.')
parser.add_argument( parser.add_argument(
...@@ -101,21 +112,21 @@ class EngineArgs: ...@@ -101,21 +112,21 @@ class EngineArgs:
help='Skip initialization of tokenizer and detokenizer') help='Skip initialization of tokenizer and detokenizer')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific tokenizer version to use. It can be a branch ' help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
...@@ -132,7 +143,7 @@ class EngineArgs: ...@@ -132,7 +143,7 @@ class EngineArgs:
action='store_true', action='store_true',
help='Trust remote code from huggingface.') help='Trust remote code from huggingface.')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=nullable_str,
default=EngineArgs.download_dir, default=EngineArgs.download_dir,
help='Directory to download and load the weights, ' help='Directory to download and load the weights, '
'default to the default cache dir of ' 'default to the default cache dir of '
...@@ -183,7 +194,7 @@ class EngineArgs: ...@@ -183,7 +194,7 @@ class EngineArgs:
'supported for common inference criteria.') 'supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=nullable_str,
default=None, default=None,
help='Path to the JSON file containing the KV cache ' help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when ' 'scaling factors. This should generally be supplied, when '
...@@ -300,7 +311,7 @@ class EngineArgs: ...@@ -300,7 +311,7 @@ class EngineArgs:
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=str, type=nullable_str,
choices=[*QUANTIZATION_METHODS, None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
...@@ -319,6 +330,14 @@ class EngineArgs: ...@@ -319,6 +330,14 @@ class EngineArgs:
default=EngineArgs.max_context_len_to_capture, default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA ' help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce', parser.add_argument('--disable-custom-all-reduce',
action='store_true', action='store_true',
...@@ -337,7 +356,7 @@ class EngineArgs: ...@@ -337,7 +356,7 @@ class EngineArgs:
'asynchronous tokenization. Ignored ' 'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.') 'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config', parser.add_argument('--tokenizer-pool-extra-config',
type=str, type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config, default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. ' help='Extra config for tokenizer pool. '
'This should be a JSON string that will be ' 'This should be a JSON string that will be '
...@@ -376,6 +395,14 @@ class EngineArgs: ...@@ -376,6 +395,14 @@ class EngineArgs:
help=('Maximum number of LoRAs to store in CPU memory. ' help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. ' 'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.')) 'Defaults to max_num_seqs.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
...@@ -384,7 +411,7 @@ class EngineArgs: ...@@ -384,7 +411,7 @@ class EngineArgs:
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
parser.add_argument( parser.add_argument(
'--image-input-type', '--image-input-type',
type=str, type=nullable_str,
default=None, default=None,
choices=[ choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType t.name.lower() for t in VisionLanguageConfig.ImageInputType
...@@ -397,7 +424,7 @@ class EngineArgs: ...@@ -397,7 +424,7 @@ class EngineArgs:
help=('Input id for image token.')) help=('Input id for image token.'))
parser.add_argument( parser.add_argument(
'--image-input-shape', '--image-input-shape',
type=str, type=nullable_str,
default=None, default=None,
help=('The biggest image input shape (worst for memory footprint) ' help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.')) 'given an input type. Only used for vLLM\'s profile_run.'))
...@@ -420,7 +447,7 @@ class EngineArgs: ...@@ -420,7 +447,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--speculative-model', '--speculative-model',
type=str, type=nullable_str,
default=EngineArgs.speculative_model, default=EngineArgs.speculative_model,
help= help=
'The name of the draft model to be used in speculative decoding.') 'The name of the draft model to be used in speculative decoding.')
...@@ -434,14 +461,28 @@ class EngineArgs: ...@@ -434,14 +461,28 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--speculative-max-model-len', '--speculative-max-model-len',
type=str, type=int,
default=EngineArgs.speculative_max_model_len, default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the ' help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip ' 'draft model. Sequences over this length will skip '
'speculation.') 'speculation.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=EngineArgs.ngram_prompt_lookup_max,
help='Max size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument('--model-loader-extra-config', parser.add_argument('--model-loader-extra-config',
type=str, type=nullable_str,
default=EngineArgs.model_loader_extra_config, default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. ' help='Extra config for model loader. '
'This will be passed to the model loader ' 'This will be passed to the model loader '
...@@ -449,6 +490,21 @@ class EngineArgs: ...@@ -449,6 +490,21 @@ class EngineArgs:
'This should be a JSON string that will be ' 'This should be a JSON string that will be '
'parsed into a dictionary.') 'parsed into a dictionary.')
parser.add_argument(
"--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument. Noted that this name(s)"
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")
return parser return parser
@classmethod @classmethod
...@@ -467,7 +523,8 @@ class EngineArgs: ...@@ -467,7 +523,8 @@ class EngineArgs:
self.code_revision, self.tokenizer_revision, self.max_model_len, self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path, self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture, self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init) self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
...@@ -493,6 +550,8 @@ class EngineArgs: ...@@ -493,6 +550,8 @@ class EngineArgs:
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
...@@ -509,6 +568,7 @@ class EngineArgs: ...@@ -509,6 +568,7 @@ class EngineArgs:
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size, lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
......
import asyncio import asyncio
import os
import time import time
from functools import partial from functools import partial
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
...@@ -7,20 +6,21 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, ...@@ -7,20 +6,21 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int( ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
class AsyncEngineDeadError(RuntimeError): class AsyncEngineDeadError(RuntimeError):
...@@ -117,7 +117,7 @@ class RequestTracker: ...@@ -117,7 +117,7 @@ class RequestTracker:
self._request_streams[request_id].put(request_output) self._request_streams[request_id].put(request_output)
if request_output.finished: if request_output.finished:
if verbose: if verbose:
logger.info(f"Finished request {request_id}.") logger.info("Finished request %s.", request_id)
self.abort_request(request_id) self.abort_request(request_id)
def process_exception(self, def process_exception(self,
...@@ -128,7 +128,7 @@ class RequestTracker: ...@@ -128,7 +128,7 @@ class RequestTracker:
"""Propagate an exception from the engine.""" """Propagate an exception from the engine."""
self._request_streams[request_id].put(exception) self._request_streams[request_id].put(exception)
if verbose: if verbose:
logger.info(f"Finished request {request_id}.") logger.info("Finished request %s.", request_id)
self.abort_request(request_id) self.abort_request(request_id)
def add_request(self, request_id: str, def add_request(self, request_id: str,
...@@ -151,7 +151,7 @@ class RequestTracker: ...@@ -151,7 +151,7 @@ class RequestTracker:
def abort_request(self, request_id: str, *, verbose: bool = False) -> None: def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration.""" """Abort a request during next background loop iteration."""
if verbose: if verbose:
logger.info(f"Aborted request {request_id}.") logger.info("Aborted request %s.", request_id)
self._finished_requests.put_nowait(request_id) self._finished_requests.put_nowait(request_id)
...@@ -210,20 +210,25 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -210,20 +210,25 @@ class _AsyncLLMEngine(LLMEngine):
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = await self.model_executor.execute_model_async( output = await self.model_executor.execute_model_async(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, execute_model_req)
scheduler_outputs.blocks_to_swap_out,
scheduler_outputs.blocks_to_copy)
else: else:
output = [] output = []
request_outputs = self._process_model_outputs( request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups, output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups) scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats. # Log stats.
if self.log_stats: self.do_log_stats(scheduler_outputs, output)
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs return request_outputs
...@@ -521,11 +526,11 @@ class AsyncLLMEngine: ...@@ -521,11 +526,11 @@ class AsyncLLMEngine:
if shortened_token_ids is not None: if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self. shortened_token_ids = shortened_token_ids[:self.
max_log_len] max_log_len]
logger.info(f"Received request {request_id}: " logger.info(
f"prompt: {shortened_prompt!r}, " "Received request %s: prompt: %r, "
f"sampling_params: {sampling_params}, " "sampling_params: %s, prompt_token_ids: %s, "
f"prompt_token_ids: {shortened_token_ids}, " "lora_request: %s.", request_id, shortened_prompt,
f"lora_request: {lora_request}.") sampling_params, shortened_token_ids, lora_request)
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
...@@ -697,9 +702,21 @@ class AsyncLLMEngine: ...@@ -697,9 +702,21 @@ class AsyncLLMEngine:
else: else:
return self.engine.get_model_config() return self.engine.get_model_config()
async def do_log_stats(self) -> None: async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_decoding_config.remote( # type: ignore
)
else:
return self.engine.get_decoding_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.do_log_stats.remote() # type: ignore await self.engine.do_log_stats.remote( # type: ignore
scheduler_outputs, model_output)
else: else:
self.engine.do_log_stats() self.engine.do_log_stats()
...@@ -717,4 +734,4 @@ class AsyncLLMEngine: ...@@ -717,4 +734,4 @@ class AsyncLLMEngine:
raise RuntimeError("Engine is dead.") from e raise RuntimeError("Engine is dead.") from e
else: else:
await self.engine.check_health_async() await self.engine.check_health_async()
logger.debug(f"Health check took {time.perf_counter()-t}s") logger.debug("Health check took %fs", time.perf_counter() - t)
...@@ -8,21 +8,23 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, ...@@ -8,21 +8,23 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ParallelConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
SequenceGroup, SequenceStage) Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
...@@ -96,29 +98,39 @@ class LLMEngine: ...@@ -96,29 +98,39 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None: ) -> None:
logger.info( logger.info(
f"Initializing an LLM engine (v{vllm.__version__}) with config: " "Initializing an LLM engine (v%s) with config: "
f"model={model_config.model!r}, " "model=%r, speculative_config=%r, tokenizer=%r, "
f"speculative_config={speculative_config!r}, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
f"tokenizer={model_config.tokenizer!r}, " "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
f"skip_tokenizer_init={model_config.skip_tokenizer_init}, " "max_seq_len=%d, download_dir=%r, load_format=%s, "
f"tokenizer_mode={model_config.tokenizer_mode}, " "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
f"revision={model_config.revision}, " "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
f"tokenizer_revision={model_config.tokenizer_revision}, " "quantization_param_path=%s, device_config=%s, "
f"trust_remote_code={model_config.trust_remote_code}, " "decoding_config=%r, seed=%d, served_model_name=%s)",
f"dtype={model_config.dtype}, " vllm.__version__,
f"max_seq_len={model_config.max_model_len}, " model_config.model,
f"download_dir={load_config.download_dir!r}, " speculative_config,
f"load_format={load_config.load_format}, " model_config.tokenizer,
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " model_config.skip_tokenizer_init,
f"disable_custom_all_reduce=" model_config.tokenizer_mode,
f"{parallel_config.disable_custom_all_reduce}, " model_config.revision,
f"quantization={model_config.quantization}, " model_config.tokenizer_revision,
f"enforce_eager={model_config.enforce_eager}, " model_config.trust_remote_code,
f"kv_cache_dtype={cache_config.cache_dtype}, " model_config.dtype,
f"quantization_param_path={model_config.quantization_param_path}, " model_config.max_model_len,
f"device_config={device_config.device}, " load_config.download_dir,
f"decoding_config={decoding_config!r}, " load_config.load_format,
f"seed={model_config.seed})") parallel_config.tensor_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
model_config.seed,
model_config.served_model_name,
)
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
...@@ -208,7 +220,8 @@ class LLMEngine: ...@@ -208,7 +220,8 @@ class LLMEngine:
if self.log_stats: if self.log_stats:
self.stat_logger = StatLogger( self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC, local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model)) labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config) self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or # Create sequence output processor, e.g. for beam search or
...@@ -237,8 +250,10 @@ class LLMEngine: ...@@ -237,8 +250,10 @@ class LLMEngine:
if self.cache_config.num_gpu_blocks_override is not None: if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(f"Overriding {num_gpu_blocks=} with " logger.info(
f"{num_gpu_blocks_override=}") "Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
...@@ -287,6 +302,12 @@ class LLMEngine: ...@@ -287,6 +302,12 @@ class LLMEngine:
# the closure used to initialize Ray worker actors # the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!") raise RuntimeError("LLMEngine should not be pickled!")
def __del__(self):
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None) return self.tokenizer.get_lora_tokenizer(None)
...@@ -414,9 +435,10 @@ class LLMEngine: ...@@ -414,9 +435,10 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens # Add the eos token id into the sampling_params to support min_tokens
# processing # processing
sampling_params.eos_token_id = seq.eos_token_id if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields) self.generation_config_fields)
...@@ -450,6 +472,10 @@ class LLMEngine: ...@@ -450,6 +472,10 @@ class LLMEngine:
"""Gets the model configuration.""" """Gets the model configuration."""
return self.model_config return self.model_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return self.scheduler.get_num_unfinished_seq_groups()
...@@ -459,9 +485,12 @@ class LLMEngine: ...@@ -459,9 +485,12 @@ class LLMEngine:
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _process_model_outputs( def _process_model_outputs(
self, output: List[SamplerOutput], self,
scheduled_seq_groups: List[SequenceGroup], output: List[SamplerOutput],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups. """Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client. Returns RequestOutputs that can be returned to the client.
...@@ -475,17 +504,15 @@ class LLMEngine: ...@@ -475,17 +504,15 @@ class LLMEngine:
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, for scheduled_seq_group, outputs, seq_group_meta in zip(
output_by_sequence_group): scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
# If all sequences in the sequence group are in DECODE, then we can self.output_processor.process_prompt_logprob(seq_group, outputs)
# process the output tokens. Otherwise, they are (chunked) prefill if seq_group_meta.do_sample:
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs) self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
...@@ -557,30 +584,36 @@ class LLMEngine: ...@@ -557,30 +584,36 @@ class LLMEngine:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model( execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots) num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else: else:
output = [] output = []
request_outputs = self._process_model_outputs( request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups, output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups) scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats. # Log stats.
if self.log_stats: self.do_log_stats(scheduler_outputs, output)
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output=output))
return request_outputs return request_outputs
def do_log_stats(self) -> None: def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active.""" """Forced log when no requests active."""
if self.log_stats: if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None)) self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output))
def _get_stats( def _get_stats(
self, self,
...@@ -596,59 +629,109 @@ class LLMEngine: ...@@ -596,59 +629,109 @@ class LLMEngine:
""" """
now = time.time() now = time.time()
# KV Cache Usage in %. # System State
# Scheduler State
num_running_sys = len(self.scheduler.running)
num_swapped_sys = len(self.scheduler.swapped)
num_waiting_sys = len(self.scheduler.waiting)
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage = 0. cpu_cache_usage_sys = 0.
if num_total_cpu > 0: if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
) )
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Scheduler State # Iteration stats
num_running = len(self.scheduler.running) num_prompt_tokens_iter = 0
num_swapped = len(self.scheduler.swapped) num_generation_tokens_iter = 0
num_waiting = len(self.scheduler.waiting) time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Iteration stats if we have scheduler output.
num_prompt_tokens = 0 # Request stats
num_generation_tokens = 0 # Latency
time_to_first_tokens = [] time_e2e_requests: List[float] = []
time_per_output_tokens = [] # Metadata
time_e2e_requests = [] num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = []
finished_reason_requests: List[str] = []
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None: if scheduler_outputs is not None:
prompt_run = scheduler_outputs.num_prefill_groups > 0 num_generation_tokens_from_prefill_groups = 0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# Number of Tokens. # the len of scheduler_outputs.scheduled_seq_groups is !=
if prompt_run: # scheduler_outputs.num_prefill_groups, this means that
num_prompt_tokens = sum( # chunked prefills have been detected.
len(scheduled_seq_group.seq_group.prompt_token_ids)
for scheduled_seq_group in for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups):
num_generation_tokens = sum( group_was_prefill = idx < scheduler_outputs.num_prefill_groups
scheduled_seq_group.seq_group.num_seqs()
for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
# Latency Timings.
time_last_iters = []
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time) # NOTE: a seq_group that completed all of its prefill tokens
time_last_iters.append(seq_group.get_last_latency(now)) # in the last iteration will have seq_group.is_prefill() = False
# Time since arrival for all finished requests. # with group_was_prefill = True
if group_was_prefill:
# Number of prompt tokens.
num_prompt_tokens_iter += (
scheduled_seq_group.token_chunk_size)
# If the seq_group just finished the prefill state
# get TTFT.
if not seq_group.is_prefill():
latency = seq_group.get_last_latency(now)
time_to_first_tokens_iter.append(latency)
# One generation token per finished prefill.
num_generation_tokens_from_prefill_groups += (
seq_group.num_seqs())
else:
# TPOTs.
latency = seq_group.get_last_latency(now)
time_per_output_tokens_iter.append(latency)
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
# the same metadata more than once per request, we standardize
# on logging request level information for finished requests,
# which can only happen once.
if seq_group.is_finished(): if seq_group.is_finished():
# Latency timings
time_e2e_requests.append(now - time_e2e_requests.append(now -
seq_group.metrics.arrival_time) seq_group.metrics.arrival_time)
time_to_first_tokens = time_last_iters if prompt_run else [] # Metadata
time_per_output_tokens = [] if prompt_run else time_last_iters num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
num_generation_tokens_requests.extend([
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
best_of_requests.append(seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
])
# Number of generation tokens.
# num_batched_tokens equals the number of prompt_tokens plus the
# number of decode_tokens in a single iteration. So,
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
# Spec decode, if enabled, emits specialized metrics from the worker in # Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output. # sampler output.
...@@ -660,17 +743,32 @@ class LLMEngine: ...@@ -660,17 +743,32 @@ class LLMEngine:
return Stats( return Stats(
now=now, now=now,
num_running=num_running,
num_swapped=num_swapped, # System stats
num_waiting=num_waiting, # Scheduler State
gpu_cache_usage=gpu_cache_usage, num_running_sys=num_running_sys,
cpu_cache_usage=cpu_cache_usage, num_swapped_sys=num_swapped_sys,
num_prompt_tokens=num_prompt_tokens, num_waiting_sys=num_waiting_sys,
num_generation_tokens=num_generation_tokens, # KV Cache Usage in %
time_to_first_tokens=time_to_first_tokens, gpu_cache_usage_sys=gpu_cache_usage_sys,
time_per_output_tokens=time_per_output_tokens, cpu_cache_usage_sys=cpu_cache_usage_sys,
time_e2e_requests=time_e2e_requests,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics, spec_decode_metrics=spec_decode_metrics,
# Request stats
# Latency
time_e2e_requests=time_e2e_requests,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
) )
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
...@@ -21,8 +23,9 @@ disable_created_metrics() ...@@ -21,8 +23,9 @@ disable_created_metrics()
# begin-metrics-definitions # begin-metrics-definitions
class Metrics: class Metrics:
labelname_finish_reason = "finished_reason"
def __init__(self, labelnames: List[str]): def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors # Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names): for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name: if hasattr(collector, "_name") and "vllm" in collector._name:
...@@ -34,18 +37,20 @@ class Metrics: ...@@ -34,18 +37,20 @@ class Metrics:
documentation='information of cache_config') documentation='information of cache_config')
# System stats # System stats
# Scheduler State
self.gauge_scheduler_running = Gauge( self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running", name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.", documentation="Number of requests currently running on GPU.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge( self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting", name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge( self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc", name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
...@@ -55,7 +60,7 @@ class Metrics: ...@@ -55,7 +60,7 @@ class Metrics:
documentation="CPU KV-cache usage. 1 means 100 percent usage.", documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames) labelnames=labelnames)
# Raw stats from last model iteration # Iteration stats
self.counter_prompt_tokens = Counter( self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
...@@ -80,18 +85,51 @@ class Metrics: ...@@ -80,18 +85,51 @@ class Metrics:
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5 1.0, 2.5
]) ])
self.histogram_e2e_request_latency = Histogram(
# Request stats
# Latency
self.histogram_e2e_time_request = Histogram(
name="vllm:e2e_request_latency_seconds", name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.", documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = Histogram(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = Histogram(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = Histogram(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = Counter(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Legacy metrics # Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge( self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s", name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.", documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
) )
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge( self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s", name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.", documentation="Average generation throughput in tokens/s.",
...@@ -102,24 +140,57 @@ class Metrics: ...@@ -102,24 +140,57 @@ class Metrics:
# end-metrics-definitions # end-metrics-definitions
def build_1_2_5_buckets(max_value: int):
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
exponent = 0
buckets = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
@dataclass @dataclass
class Stats: class Stats:
"""Created by LLMEngine for use by StatLogger.""" """Created by LLMEngine for use by StatLogger."""
now: float now: float
# System stats. # System stats (should have _sys suffix)
num_running: int # Scheduler State
num_waiting: int num_running_sys: int
num_swapped: int num_waiting_sys: int
gpu_cache_usage: float num_swapped_sys: int
cpu_cache_usage: float # KV Cache Usage in %
gpu_cache_usage_sys: float
# Raw stats from last model iteration. cpu_cache_usage_sys: float
num_prompt_tokens: int
num_generation_tokens: int # Iteration stats (should have _iter suffix)
time_to_first_tokens: List[float] num_prompt_tokens_iter: int
time_per_output_tokens: List[float] num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float] time_e2e_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
...@@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol): ...@@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol):
class StatLogger: class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" """StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally. # Metadata for logging locally.
self.last_local_log = time.time() self.last_local_log = time.time()
self.local_interval = local_interval self.local_interval = local_interval
...@@ -144,7 +216,8 @@ class StatLogger: ...@@ -144,7 +216,8 @@ class StatLogger:
# Prometheus metrics # Prometheus metrics
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys())) self.metrics = Metrics(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
...@@ -158,34 +231,66 @@ class StatLogger: ...@@ -158,34 +231,66 @@ class StatLogger:
return elapsed_time > self.local_interval return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None: def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges. # System state data
self.metrics.gauge_scheduler_running.labels(**self.labels).set( self._log_gauge(self.metrics.gauge_scheduler_running,
stats.num_running) stats.num_running_sys)
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( self._log_gauge(self.metrics.gauge_scheduler_swapped,
stats.num_swapped) stats.num_swapped_sys)
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( self._log_gauge(self.metrics.gauge_scheduler_waiting,
stats.num_waiting) stats.num_waiting_sys)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( self._log_gauge(self.metrics.gauge_gpu_cache_usage,
stats.gpu_cache_usage) stats.gpu_cache_usage_sys)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage) stats.cpu_cache_usage_sys)
# Add to token counters. # Iteration level data
self.metrics.counter_prompt_tokens.labels(**self.labels).inc( self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens) stats.num_prompt_tokens_iter)
self.metrics.counter_generation_tokens.labels(**self.labels).inc( self._log_counter(self.metrics.counter_generation_tokens,
stats.num_generation_tokens) stats.num_generation_tokens_iter)
self._log_histogram(self.metrics.histogram_time_to_first_token,
# Observe request level latencies in histograms. stats.time_to_first_tokens_iter)
for ttft in stats.time_to_first_tokens: self._log_histogram(self.metrics.histogram_time_per_output_token,
self.metrics.histogram_time_to_first_token.labels( stats.time_per_output_tokens_iter)
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens: # Request level data
self.metrics.histogram_time_per_output_token.labels( # Latency
**self.labels).observe(tpot) self._log_histogram(self.metrics.histogram_e2e_time_request,
for e2e in stats.time_e2e_requests: stats.time_e2e_requests)
self.metrics.histogram_e2e_request_latency.labels( # Metadata
**self.labels).observe(e2e) finished_reason_counter = CollectionsCounter(
stats.finished_reason_requests)
self._log_counter_labels(self.metrics.counter_request_success,
finished_reason_counter,
Metrics.labelname_finish_reason)
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests)
self._log_histogram(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
...@@ -210,8 +315,8 @@ class StatLogger: ...@@ -210,8 +315,8 @@ class StatLogger:
self._log_prometheus(stats) self._log_prometheus(stats)
# Save tracked stats for token counters. # Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens) self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now): if self._local_interval_elapsed(stats.now):
...@@ -227,14 +332,19 @@ class StatLogger: ...@@ -227,14 +332,19 @@ class StatLogger:
# Log to stdout. # Log to stdout.
logger.info( logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
f"Avg generation throughput: " "Avg generation throughput: %.1f tokens/s, "
f"{generation_throughput:.1f} tokens/s, " "Running: %d reqs, Swapped: %d reqs, "
f"Running: {stats.num_running} reqs, " "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
f"Swapped: {stats.num_swapped} reqs, " "CPU KV cache usage: %.1f%%",
f"Pending: {stats.num_waiting} reqs, " prompt_throughput,
f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, " generation_throughput,
f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%") stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval. # Reset tracked stats for next interval.
self.num_prompt_tokens = [] self.num_prompt_tokens = []
......
...@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
scheduler. scheduler.
""" """
pass pass
@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Update prompt logprobs received from outputs to seq_group."""
pass
import functools
from typing import Callable, List from typing import Callable, List
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup, from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -44,6 +45,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -44,6 +45,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self.get_tokenizer_for_seq = get_tokenizer_for_seq self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker self.stop_checker = stop_checker
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self._log_prompt_logprob_unsupported_warning_once()
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup, def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group. """Append new tokens in the outputs to sequences in the sequence group.
...@@ -80,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -80,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
valid_samples: List[SequenceOutput], valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples] output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
# Truncate to max_tokens if necessary. # Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
...@@ -104,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -104,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Incrementally append tokens to the sequence, as if we had only one new # Incrementally append tokens to the sequence, as if we had only one new
# token. # token.
for output_token_id in output_token_ids: for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
seq.append_token_id( seq.append_token_id(
token_id=output_token_id, token_id=output_token_id,
# TODO emit logprobs in multi-step decoding. logprobs=output_logprob,
logprobs={output_token_id: Logprob(0.0)},
) )
new_char_count = 0 new_char_count = 0
......
...@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
), f"{type(self)} does not support multiple outputs per step" ), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0]) return self._process_sequence_group_outputs(sequence_group, outputs[0])
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None: outputs: List[SequenceGroupOutput]) -> None:
assert len(outputs) == 1, ("Single step should only has 1 output.")
# Process prompt logprobs output = outputs[0]
prompt_logprobs = outputs.prompt_logprobs prompt_logprobs = output.prompt_logprobs
if prompt_logprobs is not None and \ if (prompt_logprobs is not None
seq_group.sampling_params.detokenize and self.detokenizer: and seq_group.sampling_params.detokenize and self.detokenizer):
self.detokenizer.decode_prompt_logprobs_inplace( self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs) seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs if not seq_group.prompt_logprobs:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group.prompt_logprobs = [None]
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process samples # Process samples
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
......
from typing import List from typing import List
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], def create_output_by_sequence_group(
num_seq_groups: int): sampler_outputs: List[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment