Unverified Commit 4a0f7888 authored by amit's avatar amit Committed by GitHub
Browse files

[Core] feat: Implement Priority Scheduling in V1 Engine (#19057)


Signed-off-by: default avataramit <amit.man@gmail.com>
Co-authored-by: default avatarRoger Wang <Rogerw0108@gmail.com>
parent c4cf2606
...@@ -45,6 +45,18 @@ For each item, our progress towards V1 support falls into one of the following s ...@@ -45,6 +45,18 @@ For each item, our progress towards V1 support falls into one of the following s
- **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later. - **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later.
- **🔴 Deprecated**: Not planned for V1 unless there is strong demand. - **🔴 Deprecated**: Not planned for V1 unless there is strong demand.
!!! note
    vLLM V1’s unified scheduler treats both prompt and output tokens the same
    way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically
    allocate a fixed token budget per request, enabling features like chunked prefills,
    prefix caching, and speculative decoding without a strict separation between prefill
    and decode phases.

    The V1 scheduler supports multiple scheduling policies, including First-Come,
    First-Served (FCFS) and priority-based scheduling (where requests are processed
    based on assigned priority, with FCFS as a tie-breaker), configurable via the
    `--scheduling-policy` argument.
### Hardware ### Hardware
| Hardware | Status | | Hardware | Status |
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import heapq
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum
from vllm.v1.request import Request
class SchedulingPolicy(Enum):
    """Policies that decide the order in which waiting requests are served."""

    # First-come-first-served.
    FCFS = "fcfs"
    # Ordered by each request's priority value.
    PRIORITY = "priority"
class RequestQueue(ABC):
    """Interface that every waiting-request queue must implement."""

    @abstractmethod
    def add_request(self, request: Request) -> None:
        """Insert ``request`` at the position the policy dictates."""

    @abstractmethod
    def pop_request(self) -> Request:
        """Remove and return the request the policy selects next."""

    @abstractmethod
    def peek_request(self) -> Request:
        """Return the front request while leaving it in the queue."""

    @abstractmethod
    def prepend_request(self, request: Request) -> None:
        """Put ``request`` at the front of the queue."""

    @abstractmethod
    def prepend_requests(self, requests: RequestQueue) -> None:
        """Put every request from ``requests`` at the front of this queue."""

    @abstractmethod
    def remove_request(self, request: Request) -> None:
        """Drop one specific request from the queue."""

    @abstractmethod
    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Drop several specific requests from the queue."""

    @abstractmethod
    def __bool__(self) -> bool:
        """Report whether the queue holds any request."""

    @abstractmethod
    def __len__(self) -> int:
        """Return how many requests the queue holds."""

    @abstractmethod
    def __iter__(self) -> Iterator[Request]:
        """Iterate over the requests in policy order."""

    @abstractmethod
    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the requests in reverse order."""
class FCFSRequestQueue(deque[Request], RequestQueue):
    """Deque-backed queue that serves requests first-come-first-served."""

    def add_request(self, request: Request) -> None:
        """Enqueue ``request`` behind everything already waiting."""
        self.append(request)

    def pop_request(self) -> Request:
        """Remove and return the request at the head of the queue."""
        return self.popleft()

    def peek_request(self) -> Request:
        """Return the head request without removing it.

        Raises:
            IndexError: If the queue is empty.
        """
        if self:
            return self[0]
        raise IndexError("peek from an empty queue")

    def prepend_request(self, request: Request) -> None:
        """Place ``request`` at the head of the queue."""
        self.appendleft(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Place every request from ``requests`` at the head of this
        queue."""
        # extendleft() pushes items one by one and therefore reverses its
        # input, so it is fed the reversed sequence.
        self.extendleft(reversed(requests))

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Drop every request in ``requests`` from the queue."""
        doomed = set(requests)
        survivors = [queued for queued in self if queued not in doomed]
        # A deque cannot be filtered in place: empty it, then refill it.
        self.clear()
        self.extend(survivors)

    def remove_request(self, request: Request) -> None:
        """Drop one specific request from the queue."""
        self.remove(request)

    def __bool__(self) -> bool:
        """Report whether at least one request is waiting."""
        return len(self) > 0

    def __len__(self) -> int:
        """Return the number of waiting requests."""
        return super().__len__()

    def __iter__(self) -> Iterator[Request]:
        """Iterate from the head of the queue to its tail."""
        return super().__iter__()

    def __reversed__(self) -> Iterator[Request]:
        """Iterate from the tail of the queue to its head."""
        return super().__reversed__()
class PriorityRequestQueue(RequestQueue):
    """
    A heap-backed priority queue of requests.

    Requests with a smaller value of `priority` are processed first.
    If multiple requests have the same priority, the one with the earlier
    `arrival_time` is processed first. Requests that tie on both fields are
    processed in the order they were inserted into this queue.
    """

    def __init__(self) -> None:
        # Heap entries are (priority, arrival_time, sequence, request).
        # The sequence number is unique per insertion, so tuple comparison
        # is always settled before it reaches the Request object. Without
        # it, two requests with equal priority and arrival_time would make
        # heapq compare the Request objects themselves, which raises
        # TypeError unless Request defines ordering.
        self._heap: list[tuple[int, float, int, Request]] = []
        self._next_sequence = 0

    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy."""
        entry = (request.priority, request.arrival_time,
                 self._next_sequence, request)
        self._next_sequence += 1
        heapq.heappush(self._heap, entry)

    def pop_request(self) -> Request:
        """Pop the next request according to priority policy.

        Raises:
            IndexError: If the queue is empty.
        """
        if not self._heap:
            raise IndexError("pop from empty heap")
        return heapq.heappop(self._heap)[-1]

    def peek_request(self) -> Request:
        """Return the next request without removing it.

        Raises:
            IndexError: If the queue is empty.
        """
        if not self._heap:
            raise IndexError("peek from empty heap")
        return self._heap[0][-1]

    def prepend_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        self.add_request(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Add all requests from another queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        for request in requests:
            self.add_request(request)

    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue.

        A request that is not in the queue is ignored."""
        self._heap = [entry for entry in self._heap if entry[-1] != request]
        # Filtering can break the heap invariant; rebuild it.
        heapq.heapify(self._heap)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue.

        Requests that are not in the queue are ignored."""
        requests_to_remove = set(requests)
        self._heap = [entry for entry in self._heap
                      if entry[-1] not in requests_to_remove]
        # Filtering can break the heap invariant; rebuild it.
        heapq.heapify(self._heap)

    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        return bool(self._heap)

    def __len__(self) -> int:
        """Get number of requests in queue."""
        return len(self._heap)

    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to priority policy."""
        # Pop from a copy so that iterating does not consume the queue.
        heap_copy = self._heap[:]
        while heap_copy:
            yield heapq.heappop(heap_copy)[-1]

    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse priority order."""
        return reversed(list(self))
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
    """Build an empty request queue for the given scheduling policy.

    Raises:
        ValueError: If ``policy`` is neither ``SchedulingPolicy.PRIORITY``
            nor ``SchedulingPolicy.FCFS``.
    """
    if policy == SchedulingPolicy.PRIORITY:
        return PriorityRequestQueue()
    if policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()
    raise ValueError(f"Unknown scheduling policy: {policy}")
...@@ -22,6 +22,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager ...@@ -22,6 +22,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
...@@ -94,8 +96,16 @@ class Scheduler(SchedulerInterface): ...@@ -94,8 +96,16 @@ class Scheduler(SchedulerInterface):
# req_id -> Request # req_id -> Request
self.requests: dict[str, Request] = {} self.requests: dict[str, Request] = {}
# Scheduling policy
if self.scheduler_config.policy == "priority":
self.policy = SchedulingPolicy.PRIORITY
elif self.scheduler_config.policy == "fcfs":
self.policy = SchedulingPolicy.FCFS
else:
raise ValueError(
f"Unknown scheduling policy: {self.scheduler_config.policy}")
# Priority queues for requests. # Priority queues for requests.
self.waiting: deque[Request] = deque() self.waiting = create_request_queue(self.policy)
self.running: list[Request] = [] self.running: list[Request] = []
# The request IDs that are finished in between the previous and the # The request IDs that are finished in between the previous and the
...@@ -247,7 +257,15 @@ class Scheduler(SchedulerInterface): ...@@ -247,7 +257,15 @@ class Scheduler(SchedulerInterface):
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
preempted_req = self.running.pop() if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
...@@ -255,7 +273,7 @@ class Scheduler(SchedulerInterface): ...@@ -255,7 +273,7 @@ class Scheduler(SchedulerInterface):
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.appendleft(preempted_req) self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req) preempted_reqs.append(preempted_req)
if preempted_req == request: if preempted_req == request:
# No more request to preempt. # No more request to preempt.
...@@ -311,9 +329,9 @@ class Scheduler(SchedulerInterface): ...@@ -311,9 +329,9 @@ class Scheduler(SchedulerInterface):
if req.lora_request and req.lora_request.lora_int_id > 0) if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped # Use a temporary RequestQueue to collect requests that need to be
# and put back at the head of the waiting queue later # skipped and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque() skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
...@@ -321,7 +339,7 @@ class Scheduler(SchedulerInterface): ...@@ -321,7 +339,7 @@ class Scheduler(SchedulerInterface):
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
request = self.waiting[0] request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
...@@ -332,8 +350,8 @@ class Scheduler(SchedulerInterface): ...@@ -332,8 +350,8 @@ class Scheduler(SchedulerInterface):
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id) request.request_id)
self.waiting.popleft() self.waiting.pop_request()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.prepend_request(request)
continue continue
# Skip request if the structured output request is still waiting # Skip request if the structured output request is still waiting
...@@ -343,19 +361,18 @@ class Scheduler(SchedulerInterface): ...@@ -343,19 +361,18 @@ class Scheduler(SchedulerInterface):
if structured_output_req and structured_output_req.grammar: if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
else: else:
self.waiting.popleft() self.waiting.pop_request()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.prepend_request(request)
continue continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if self.lora_config and request.lora_request and ( if (self.lora_config and request.lora_request and
len(scheduled_loras) == self.lora_config.max_loras (len(scheduled_loras) == self.lora_config.max_loras and
and request.lora_request.lora_int_id request.lora_request.lora_int_id not in scheduled_loras)):
not in scheduled_loras):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.popleft() self.waiting.pop_request()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.prepend_request(request)
continue continue
num_external_computed_tokens = 0 num_external_computed_tokens = 0
...@@ -407,8 +424,8 @@ class Scheduler(SchedulerInterface): ...@@ -407,8 +424,8 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \ if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget: num_new_tokens > token_budget:
self.waiting.popleft() self.waiting.pop_request()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.prepend_request(request)
continue continue
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
...@@ -448,17 +465,19 @@ class Scheduler(SchedulerInterface): ...@@ -448,17 +465,19 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens, num_external_computed_tokens,
) )
self.waiting.popleft() # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async: if load_kv_async:
# If loading async, allocate memory and put request # If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state. # into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue continue
if request.use_structured_output: if request.use_structured_output:
structured_output_request_ids[ structured_output_request_ids[request.request_id] = (
request.request_id] = req_index req_index)
req_index += 1 req_index += 1
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
...@@ -494,7 +513,7 @@ class Scheduler(SchedulerInterface): ...@@ -494,7 +513,7 @@ class Scheduler(SchedulerInterface):
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests) self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
...@@ -896,7 +915,7 @@ class Scheduler(SchedulerInterface): ...@@ -896,7 +915,7 @@ class Scheduler(SchedulerInterface):
return len(self.running), len(self.waiting) return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.add_request(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED) request.record_event(EngineCoreEventType.QUEUED)
...@@ -917,16 +936,31 @@ class Scheduler(SchedulerInterface): ...@@ -917,16 +936,31 @@ class Scheduler(SchedulerInterface):
else: else:
request_ids = set(request_ids) request_ids = set(request_ids)
running_requests_to_remove = []
waiting_requests_to_remove = []
valid_requests = []
# First pass: collect requests to remove from queues
for req_id in request_ids: for req_id in request_ids:
request = self.requests.get(req_id) request = self.requests.get(req_id)
if request is None: if request is None:
# Invalid request ID. # Invalid request ID.
continue continue
valid_requests.append(request)
if request.status == RequestStatus.RUNNING: if request.status == RequestStatus.RUNNING:
self.running.remove(request) running_requests_to_remove.append(request)
else: else:
self.waiting.remove(request) waiting_requests_to_remove.append(request)
# Remove all requests from queues at once for better efficiency
for request in running_requests_to_remove:
self.running.remove(request)
if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove)
# Second pass: set status and free requests
for request in valid_requests:
request.status = finished_status request.status = finished_status
self._free_request(request) self._free_request(request)
......
...@@ -68,6 +68,7 @@ class EngineCoreRequest( ...@@ -68,6 +68,7 @@ class EngineCoreRequest(
# belong to, to cover a race condition where the request is sent before # belong to, to cover a race condition where the request is sent before
# a wave finished notification is received. # a wave finished notification is received.
current_wave: int = 0 current_wave: int = 0
priority: int = 0
class EngineCoreEventType(enum.IntEnum): class EngineCoreEventType(enum.IntEnum):
......
...@@ -219,8 +219,6 @@ class Processor: ...@@ -219,8 +219,6 @@ class Processor:
# TODO(woosuk): Support encoder-decoder models. # TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params, lora_request) self._validate_params(params, lora_request)
if priority != 0:
raise ValueError("V1 does not support priority yet.")
if trace_headers is not None: if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.") raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None: if prompt_adapter_request is not None:
...@@ -340,6 +338,7 @@ class Processor: ...@@ -340,6 +338,7 @@ class Processor:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"), cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
...@@ -30,18 +31,23 @@ class Request: ...@@ -30,18 +31,23 @@ class Request:
pooling_params: Optional[PoolingParams], pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int], eos_token_id: Optional[int],
client_index: int = 0, client_index: int = 0,
arrival_time: Optional[float] = None,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
priority: int = 0,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index self.client_index = client_index
self.priority = priority
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.pooling_params = pooling_params self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.structured_output_request = structured_output_request self.structured_output_request = structured_output_request
self.arrival_time = arrival_time if arrival_time is not None else \
time.time()
self.status = RequestStatus.WAITING self.status = RequestStatus.WAITING
if sampling_params and sampling_params.guided_decoding is not None: if sampling_params and sampling_params.guided_decoding is not None:
...@@ -118,11 +124,13 @@ class Request: ...@@ -118,11 +124,13 @@ class Request:
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest( structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params) \ sampling_params=request.sampling_params) \
if request.sampling_params else None, if request.sampling_params else None,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
priority=request.priority,
) )
def append_output_token_ids( def append_output_token_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment