Unverified commit b029de99, authored by Woosuk Kwon, committed by GitHub
Browse files

[Optimization] Make new_block_ids None if empty (#23262)


Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
parent bbea1cef
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Literal, Optional, overload
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -37,7 +37,24 @@ class KVCacheBlocks: ...@@ -37,7 +37,24 @@ class KVCacheBlocks:
tuple(blk1 + blk2 tuple(blk1 + blk2
for blk1, blk2 in zip(self.blocks, other.blocks))) for blk1, blk2 in zip(self.blocks, other.blocks)))
def get_block_ids(self) -> tuple[list[int], ...]: @overload
def get_block_ids(
self,
allow_none: Literal[False] = False,
) -> tuple[list[int], ...]:
...
@overload
def get_block_ids(
self,
allow_none: Literal[True] = True,
) -> Optional[tuple[list[int], ...]]:
...
def get_block_ids(
self,
allow_none: bool = False,
):
""" """
Converts the KVCacheBlocks instance to block_ids. Converts the KVCacheBlocks instance to block_ids.
...@@ -46,6 +63,8 @@ class KVCacheBlocks: ...@@ -46,6 +63,8 @@ class KVCacheBlocks:
* the outer tuple corresponds to KV cache groups * the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group * each inner list contains the block_ids of the blocks in that group
""" """
if allow_none and all(len(group) == 0 for group in self.blocks):
return None
return tuple([blk.block_id for blk in group] for group in self.blocks) return tuple([blk.block_id for blk in group] for group in self.blocks)
def get_unhashed_block_ids(self) -> list[int]: def get_unhashed_block_ids(self) -> list[int]:
...@@ -348,10 +367,13 @@ class KVCacheManager: ...@@ -348,10 +367,13 @@ class KVCacheManager:
""" """
return self.block_pool.take_events() return self.block_pool.take_events()
def get_blocks(self, request_id: str) -> KVCacheBlocks:
"""Get the blocks of a request."""
return KVCacheBlocks(self.coordinator.get_blocks(request_id))
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request.""" """Get the block ids of a request."""
return KVCacheBlocks( return self.get_blocks(request_id).get_block_ids()
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled."""
......
...@@ -91,7 +91,7 @@ class CachedRequestData: ...@@ -91,7 +91,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty. # When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]] new_block_ids: list[Optional[tuple[list[int], ...]]]
num_computed_tokens: list[int] num_computed_tokens: list[int]
@property @property
......
...@@ -19,7 +19,7 @@ from vllm.logger import init_logger ...@@ -19,7 +19,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
...@@ -185,7 +185,7 @@ class Scheduler(SchedulerInterface): ...@@ -185,7 +185,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding. # uses structured decoding.
structured_output_request_ids: dict[str, int] = {} structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
# Encoder-related. # Encoder-related.
...@@ -288,8 +288,7 @@ class Scheduler(SchedulerInterface): ...@@ -288,8 +288,7 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = ( req_to_new_blocks[request.request_id] = new_blocks
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -496,8 +495,8 @@ class Scheduler(SchedulerInterface): ...@@ -496,8 +495,8 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = ( req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id)) self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
...@@ -546,8 +545,8 @@ class Scheduler(SchedulerInterface): ...@@ -546,8 +545,8 @@ class Scheduler(SchedulerInterface):
) )
# Construct the scheduler output. # Construct the scheduler output.
new_reqs_data = [ new_reqs_data = [
NewRequestData.from_request(req, NewRequestData.from_request(
req_to_new_block_ids[req.request_id]) req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs for req in scheduled_new_reqs
] ]
cached_reqs_data = self._make_cached_request_data( cached_reqs_data = self._make_cached_request_data(
...@@ -555,7 +554,7 @@ class Scheduler(SchedulerInterface): ...@@ -555,7 +554,7 @@ class Scheduler(SchedulerInterface):
scheduled_resumed_reqs, scheduled_resumed_reqs,
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_blocks,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -628,11 +627,11 @@ class Scheduler(SchedulerInterface): ...@@ -628,11 +627,11 @@ class Scheduler(SchedulerInterface):
resumed_reqs: list[Request], resumed_reqs: list[Request],
num_scheduled_tokens: dict[str, int], num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]], spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]], req_to_new_blocks: dict[str, KVCacheBlocks],
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = [] new_block_ids: list[Optional[tuple[list[int], ...]]] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
use_connector = self.connector is not None use_connector = self.connector is not None
...@@ -655,7 +654,8 @@ class Scheduler(SchedulerInterface): ...@@ -655,7 +654,8 @@ class Scheduler(SchedulerInterface):
# out of bounds errors. TODO: Remove this once the KVConnector # out of bounds errors. TODO: Remove this once the KVConnector
# is updated to handle token IDs properly. # is updated to handle token IDs properly.
new_token_ids.append([]) new_token_ids.append([])
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True))
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do # Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list. # in-place appending so that we don't need to allocate a new list.
......
...@@ -574,11 +574,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -574,11 +574,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the block IDs. # Update the block IDs.
if not resumed_from_preemption: if not resumed_from_preemption:
# Append the new blocks to the existing block IDs. if new_block_ids is not None:
for block_ids, new_ids in zip(req_state.block_ids, # Append the new blocks to the existing block IDs.
new_block_ids): for block_ids, new_ids in zip(req_state.block_ids,
block_ids.extend(new_ids) new_block_ids):
block_ids.extend(new_ids)
else: else:
assert new_block_ids is not None
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids req_state.block_ids = new_block_ids
...@@ -594,7 +596,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -594,7 +596,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) if new_block_ids is not None:
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
......
...@@ -418,11 +418,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -418,11 +418,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
if not resumed_from_preemption: if not resumed_from_preemption:
# Append the new blocks to the existing block IDs. if new_block_ids is not None:
for block_ids, new_ids in zip(req_state.block_ids, # Append the new blocks to the existing block IDs.
new_block_ids): for block_ids, new_ids in zip(req_state.block_ids,
block_ids.extend(new_ids) new_block_ids):
block_ids.extend(new_ids)
else: else:
assert new_block_ids is not None
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids req_state.block_ids = new_block_ids
...@@ -438,7 +440,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -438,7 +440,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) if new_block_ids is not None:
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment