Unverified Commit 14fdd527 authored by harrisonlimh, committed by GitHub

feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)

parent f949ad57
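Before the diff, a minimal client-side sketch of the new "priority" field (hedged: the host, port, and prompt are illustrative, and it assumes a server launched with --enable-priority-scheduling; only the "priority" key comes from this commit):

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed local SGLang server

payload = {
    "text": "User: What is the capital of France?\nAssistant:",
    "sampling_params": {"temperature": 0, "max_new_tokens": 50},
    "priority": 10,  # higher integer value = higher priority by default
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
print(resp.status_code, resp.json().get("meta_info", {}).get("e2e_latency"))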
......@@ -228,6 +228,8 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
......@@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
......@@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel):
# The request id.
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
class EmbeddingObject(BaseModel):
......
......@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)
......
......@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)
......
......@@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request = EmbeddingReqInput(
**prompt_kwargs,
rid=request.rid,
priority=request.priority,
)
return adapted_request, request
......
......@@ -570,6 +570,7 @@ class TokenizedGenerateReqInput:
token_ids_logprob: List[int]
# Whether to stream output
stream: bool
# Whether to return hidden states
return_hidden_states: bool = False
......@@ -656,6 +657,8 @@ class EmbeddingReqInput:
modalities: Optional[List[str]] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
# Priority for the request
priority: Optional[int] = None
# For background responses (OpenAI responses API)
background: bool = False
......@@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput:
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request
priority: Optional[int] = None
@dataclass
......
......@@ -453,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
......@@ -504,6 +505,7 @@ class Req:
self.stream = stream
self.eos_token_ids = eos_token_ids
self.vocab_size = vocab_size
self.priority = priority
# For incremental decoding
# ----- | --------- read_ids -------|
......@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = (
len(req.prefix_indices) // server_args.page_size
) * server_args.page_size
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
req.reset_for_retract()
self.release_req(idx, len(sorted_indices), server_args)
if len(retracted_reqs) == 0:
# Corner case: only one request left
......@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return retracted_reqs, new_estimate_ratio
def release_req(self, idx: int, remaining_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
seq_lens_cpu = self.seq_lens.cpu().numpy()
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = (
len(req.prefix_indices) // server_args.page_size
) * server_args.page_size
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
num_tokens = remaining_req_count * global_config.retract_decode_steps
self._evict_tree_cache_if_needed(num_tokens)
req.reset_for_retract()
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
......
......@@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
......@@ -82,10 +83,14 @@ class SchedulePolicy:
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool,
enable_priority_scheduling: bool,
schedule_low_priority_values_first: bool,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache
self.enable_priority_scheduling = enable_priority_scheduling
self.schedule_low_priority_values_first = schedule_low_priority_values_first
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
......@@ -97,7 +102,10 @@ class SchedulePolicy:
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
if self.enable_priority_scheduling:
SchedulePolicy._sort_by_priority_and_fcfs(
waiting_queue, self.schedule_low_priority_values_first
)
return False
policy = self._determine_active_policy(waiting_queue)
......@@ -120,12 +128,15 @@ class SchedulePolicy:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(waiting_queue)
SchedulePolicy._sort_by_longest_output(
waiting_queue,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
......@@ -231,15 +242,39 @@ class SchedulePolicy:
)
@staticmethod
def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
"""Sorts the waiting queue based on the longest output (max_new_tokens)."""
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
def _sort_by_longest_output(
waiting_queue: List[Req],
enable_priority_scheduling: bool,
schedule_low_priority_values_first: bool,
) -> None:
"""Sorts the waiting queue based on the longest output (max_new_tokens). If using priority scheduling, sort by priority first."""
if enable_priority_scheduling:
if schedule_low_priority_values_first:
waiting_queue.sort(
key=lambda x: (x.priority, -x.sampling_params.max_new_tokens)
)
else:
waiting_queue.sort(
key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens)
)
else:
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
@staticmethod
def _sort_randomly(waiting_queue: List[Req]) -> None:
"""Shuffles the waiting queue randomly."""
random.shuffle(waiting_queue)
@staticmethod
def _sort_by_priority_and_fcfs(
waiting_queue: List[Req], schedule_low_priority_values_first: bool
) -> None:
"""Sorts the waiting queue based on the request priority then received titmestamp."""
if schedule_low_priority_values_first:
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
else:
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
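For reference, the ordering produced by _sort_by_priority_and_fcfs can be reproduced with a minimal stand-in for Req (a sketch only; FakeReq is illustrative, not the real class):

from dataclasses import dataclass

@dataclass
class FakeReq:
    rid: int
    priority: int
    queue_time_start: float

queue = [FakeReq(1, 5, 0.1), FakeReq(2, 5, 0.0), FakeReq(3, 9, 0.2)]

# Default mode: a larger priority value means higher priority; FCFS breaks ties.
queue.sort(key=lambda r: (-r.priority, r.queue_time_start))
assert [r.rid for r in queue] == [3, 2, 1]

# With --schedule-low-priority-values-first: a smaller value means higher priority.
queue.sort(key=lambda r: (r.priority, r.queue_time_start))
assert [r.rid for r in queue] == [2, 1, 3]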
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
for child in cur_node.children.values():
......@@ -279,6 +314,7 @@ class PrefillAdder:
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
self.page_size = page_size
self.tree_cache = tree_cache
......@@ -295,6 +331,7 @@ class PrefillAdder:
self.req_states = None
self.can_run_list = []
self.preempt_list = []
self.new_chunked_req = None
self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
......@@ -303,11 +340,7 @@ class PrefillAdder:
if running_batch is not None:
self.rem_total_token_offset += sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
self._get_running_request_total_token_offset(r)
for r in running_batch.reqs
]
)
......@@ -316,6 +349,19 @@ class PrefillAdder:
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
)
def _get_running_request_total_token_offset(self, req: Req) -> int:
return (
min(
(req.sampling_params.max_new_tokens - len(req.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
)
@property
def rem_total_tokens(self):
if self.is_hybrid:
......@@ -568,3 +614,61 @@ class PrefillAdder:
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()
def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
"""
Preempt running requests to make room for the new request when the priority gap meets the threshold and enough tokens can be freed.
Returns True if preemption was committed and the new request can be scheduled.
"""
# Iterate running requests to find preemptible requests
if server_args.schedule_low_priority_values_first:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (-x.priority, -x.queue_time_start),
)
else:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (x.priority, -x.queue_time_start),
)
preemptible_reqs = []
min_tokens_to_remove = (
req.extend_input_len
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
- self.rem_total_tokens
)
for running_req in sorted_running_reqs:
if running_req in self.preempt_list:
continue
# Priority difference needs to meet the threshold to be preemptible.
priority_diff = req.priority - running_req.priority
if server_args.schedule_low_priority_values_first:
priority_diff *= -1
if priority_diff > self.priority_scheduling_preemption_threshold:
preemptible_reqs.append(running_req)
min_tokens_to_remove -= self._get_running_request_total_token_offset(
running_req
)
# Bail out if nothing is preemptible or the freed tokens would not cover the new request.
if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
return False
# Preempt running requests. Release allocated resources for immediate usage.
preemptible_reqs = set(preemptible_reqs)
keep_indices = []
release_counter = 0
for i, running_req in enumerate(self.running_batch.reqs):
if running_req in preemptible_reqs:
self.rem_total_token_offset -= (
self._get_running_request_total_token_offset(running_req)
)
release_counter += 1
self.running_batch.release_req(
i, len(self.running_batch.reqs) - release_counter, server_args
)
else:
keep_indices.append(i)
self.running_batch.filter_batch(keep_indices=keep_indices)
self.preempt_list.extend(preemptible_reqs)
return True
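The preemption path above applies two gates: a strict priority-gap threshold and a token budget. A standalone sketch of the threshold gate (illustrative values; the default threshold of 10 comes from server_args further down):

def exceeds_preemption_threshold(
    new_priority: int, running_priority: int, threshold: int, low_values_first: bool
) -> bool:
    # Mirrors the priority_diff check in preempt_to_schedule: the gap must be
    # strictly greater than the configured threshold.
    diff = new_priority - running_priority
    if low_values_first:
        diff = -diff
    return diff > threshold

assert exceeds_preemption_threshold(20, 0, threshold=10, low_values_first=False)     # preemptible
assert not exceeds_preemption_threshold(5, 0, threshold=10, low_values_first=False)  # gap too small
assert exceeds_preemption_threshold(0, 20, threshold=10, low_values_first=True)      # low values win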
......@@ -243,6 +243,13 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.enable_priority_scheduling = server_args.enable_priority_scheduling
self.schedule_low_priority_values_first = (
server_args.schedule_low_priority_values_first
)
self.priority_scheduling_preemption_threshold = (
server_args.priority_scheduling_preemption_threshold
)
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
......@@ -487,7 +494,12 @@ class Scheduler(
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
# Enable preemption for priority scheduling.
self.try_preemption = self.enable_priority_scheduling
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
......@@ -1150,20 +1162,6 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
......@@ -1233,6 +1231,7 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
......@@ -1382,6 +1381,9 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
......@@ -1408,7 +1410,70 @@ class Scheduler(
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
else:
self.waiting_queue.extend(reqs)
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
if self.enable_priority_scheduling and req.priority is None:
if self.schedule_low_priority_values_first:
req.priority = sys.maxsize
else:
req.priority = -sys.maxsize - 1
elif not self.enable_priority_scheduling and req.priority is not None:
abort_req = AbortReq(
req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
if (
self.max_queued_requests is None
or len(self.waiting_queue) + 1 <= self.max_queued_requests
):
return False
# Reject the incoming request by default.
req_to_abort = recv_req
message = "The request queue is full."
if self.enable_priority_scheduling:
# With priority scheduling, consider aborting an existing request based on priority.
# direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
# max(...) with key (direction * priority, queue_time_start) picks the least-preferred request.
# Ties: the later queue_time_start (newer request) is evicted first. An existing request is
# aborted only if the incoming one is strictly preferred.
direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: (
direction * item[1].priority,
item[1].queue_time_start,
)
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = (
direction * recv_req.priority < direction * candidate_req.priority
)
if abort_existing_req:
self.waiting_queue.pop(idx)
req_to_abort = candidate_req
message = "The request is aborted by a higher priority request."
self.send_to_tokenizer.send_pyobj(
AbortReq(
req_to_abort.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": message,
},
)
)
return req_to_abort.rid == recv_req.rid
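The direction trick in _abort_on_queued_limit picks the least-preferred request with a single max(). A standalone sketch of that decision (FakeReq is illustrative, not the real Req):

from typing import List, NamedTuple

class FakeReq(NamedTuple):
    rid: str
    priority: int
    queue_time_start: float

def pick_request_to_abort(
    incoming: FakeReq, waiting: List[FakeReq], low_values_first: bool
) -> FakeReq:
    # Abort the least-preferred queued request only if the incoming one is
    # strictly preferred; otherwise the incoming request itself is rejected.
    direction = 1 if low_values_first else -1
    candidate = max(
        waiting, key=lambda r: (direction * r.priority, r.queue_time_start)
    )
    if direction * incoming.priority < direction * candidate.priority:
        return candidate
    return incoming

waiting = [FakeReq("a", 5, 0.0), FakeReq("b", 2, 0.1)]
assert pick_request_to_abort(FakeReq("c", 9, 0.2), waiting, False).rid == "b"
assert pick_request_to_abort(FakeReq("d", 1, 0.3), waiting, False).rid == "d"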
def handle_embedding_request(
self,
......@@ -1420,6 +1485,7 @@ class Scheduler(
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
)
req.tokenizer = self.tokenizer
......@@ -1680,6 +1746,10 @@ class Scheduler(
if self.grammar_queue:
self.move_ready_grammar_requests()
if self.try_preemption:
# Reset batch_is_full to try preemption with a prefill adder.
self.running_batch.batch_is_full = False
# Handle the cases where prefill is not allowed
if (
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
......@@ -1692,7 +1762,11 @@ class Scheduler(
# as the space for the chunked request has just been released.
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
if (
self.get_num_allocatable_reqs(running_bs) <= 0
and not self.chunked_req
and not self.try_preemption
):
self.running_batch.batch_is_full = True
return None
......@@ -1712,6 +1786,7 @@ class Scheduler(
self.max_prefill_tokens,
self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
if self.chunked_req is not None:
......@@ -1732,15 +1807,19 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
self.running_batch.batch_is_full = True
break
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# In prefill mode, prealloc queue and transfer queue can also take memory,
# so we need to check against the actual available size.
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
self.running_batch.batch_is_full = True
if self.running_batch.batch_is_full:
if not self.try_preemption:
break
if not adder.preempt_to_schedule(req, self.server_args):
break
if self.enable_hicache_storage:
......@@ -1777,6 +1856,8 @@ class Scheduler(
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
if adder.new_chunked_req is not None:
assert self.chunked_req is None
......
......@@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
......@@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
mm_inputs,
token_type_ids,
sampling_params,
priority=obj.priority,
)
return tokenized_obj
......
......@@ -149,8 +149,8 @@ class TpModelWorker:
assert self.max_running_requests > 0, "max_running_request is zero"
self.max_queued_requests = server_args.max_queued_requests
assert (
self.max_queued_requests > 0
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
self.max_queued_requests is None or self.max_queued_requests >= 1
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
self.max_req_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
......
......@@ -172,11 +172,14 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = sys.maxsize
max_queued_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "fcfs"
enable_priority_scheduling: bool = False
schedule_low_priority_values_first: bool = False
priority_scheduling_preemption_threshold: int = 10
schedule_conservativeness: float = 1.0
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
......@@ -1166,6 +1169,24 @@ class ServerArgs:
choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
help="The scheduling policy of the requests.",
)
parser.add_argument(
"--enable-priority-scheduling",
action="store_true",
default=ServerArgs.enable_priority_scheduling,
help="Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default.",
)
parser.add_argument(
"--schedule-low-priority-values-first",
action="store_true",
default=ServerArgs.schedule_low_priority_values_first,
help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.",
)
parser.add_argument(
"--priority-scheduling-preemption-threshold",
type=int,
default=ServerArgs.priority_scheduling_preemption_threshold,
help="Minimum difference in priorities for an incoming request to have to preempt running request(s).",
)
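Taken together, the new flags are supplied at server launch. A hedged launch sketch (the model path is a placeholder; the flag names and the default threshold of 10 are the ones defined above):

import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    "--enable-priority-scheduling",
    "--schedule-low-priority-values-first",  # optional: lower values win
    "--priority-scheduling-preemption-threshold", "10",  # the default, made explicit
]
server = subprocess.Popen(cmd)  # server now accepts a per-request "priority" field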
parser.add_argument(
"--schedule-conservativeness",
type=float,
......@@ -2455,6 +2476,13 @@ class ServerArgs:
"--generation-tokens-buckets", self.generation_tokens_buckets
)
# Check scheduling policy
if self.enable_priority_scheduling:
assert self.schedule_policy in [
"fcfs",
"lof",
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......
......@@ -17,7 +17,7 @@ from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
......@@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests(
return await asyncio.gather(*tasks)
async def send_concurrent_generate_requests_with_custom_params(
base_url: str,
custom_params: List[dict[str, Any]],
) -> List[Tuple[int, Any]]:
"""Sends generate requests concurrently with custom parameters and returns (status_code, response_json) tuples. Max concurrency equals len(custom_params)."""
base_payload = {
"text": """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
""",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
},
}
async def async_generate_with_priority(req):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{base_url}/generate",
json=req,
) as response:
resp_json = await response.json()
return (response.status, resp_json)
tasks = []
for c in custom_params:
req = base_payload.copy()
req.update(c)
tasks.append(asyncio.create_task(async_generate_with_priority(req)))
return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(
......
......@@ -95,6 +95,7 @@ suites = {
TestFile("test_original_logprobs.py", 200),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_priority_scheduling.py", 100),
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105),
TestFile("test_regex_constrained.py", 64),
......
import asyncio
import os
import re
import unittest
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
CustomTestCase,
popen_launch_server,
send_concurrent_generate_requests_with_custom_params,
)
class TestPriorityScheduling(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--max-running-requests", # Enforce max request concurrency is 1
"1",
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
_verify_max_running_requests_and_max_queued_request_validation(1, 3)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def test_priority_scheduling_request_ordering_validation(self):
"""Verify pending requests are ordered by priority and received timestamp."""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first
{"priority": 1}, # third
{"priority": 1}, # fourth
{"priority": 2}, # second
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2]
def test_priority_scheduling_existing_requests_abortion_validation(self):
"""Verify lower priority requests are aborted when incoming requests have higher priority"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 1,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first and holds the running queue capacity
{"priority": 2}, # aborted by request 5
{"priority": 3}, # aborted by request 6
{"priority": 4}, # aborted by request 7
{"priority": 5}, # fourth
{"priority": 6}, # third
{"priority": 7}, # second
],
)
)
expected_status_and_error_messages = [
(200, None),
(503, "The request is aborted by a higher priority request."),
(503, "The request is aborted by a higher priority request."),
(503, "The request is aborted by a higher priority request."),
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4]
def test_priority_scheduling_incoming_request_rejection_validation(self):
"""Verify incoming requests are rejected when existing requests have higher priority"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 7,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first and holds the running queue capacity
{"priority": 6}, # second
{"priority": 5}, # third
{"priority": 4}, # fourth
{"priority": 3}, # rejected
{"priority": 2}, # rejected
{"priority": 1}, # rejected
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
(200, None),
(503, "The request queue is full."),
(503, "The request queue is full."),
(503, "The request queue is full."),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3]
def test_priority_scheduling_preemption_meeting_threshold_validation(self):
"""Verify running requests are preempted by requests with priorities meeting the preemption threshold"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first then preempted or pushed by later requests, and finishes last.
{
"priority": 10,
"sampling_params": {"max_new_tokens": 10000},
}, # scheduled after the third request, and finishes second.
{
"priority": 20,
"sampling_params": {"max_new_tokens": 10000},
}, # finishes first.
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0]
def test_priority_scheduling_preemption_below_threshold_validation(self):
"""Verify running requests are not preempted by requests with priorities below preemption threshold"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
},
{
"priority": 5,
"sampling_params": {"max_new_tokens": 10000},
},
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[1]
class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--max-running-requests", # Enforce max request concurrency is 2
"2",
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
_verify_max_running_requests_and_max_queued_request_validation(2, 3)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def test_priority_scheduling_with_multiple_running_requests_preemption(self):
"""Verify preempting a subset of running requests is safe."""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 10,
"sampling_params": {"max_new_tokens": 10000},
}, # finishes first
{
"priority": 5,
"sampling_params": {"max_new_tokens": 10000},
}, # preempted by fourth request, then finishes third
{
"priority": 15,
"sampling_params": {"max_new_tokens": 10000},
}, # preempt the first request
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
(200, None),
]
_verify_generate_responses(responses, expected_status_and_error_messages, [])
def _verify_generate_responses(
responses: List[Tuple[int, Any]],
expected_code_and_error_message: List[Tuple[int, Optional[str]]],
e2e_latencies: List[Optional[float]],
):
"""
Verify that each generate response matches the expected status code and error message.
Also collects e2e latency info (None for failed requests) to validate scheduling and processing order.
"""
for got, expected in zip(responses, expected_code_and_error_message):
got_status, got_json = got
expected_status, expected_err_msg = expected
# Check status code is as expected
assert got_status == expected_status
# Check error message content or fields' existence based on status code
if got_status != 200:
assert got_json["object"] == "error"
assert got_json["message"] == expected_err_msg
else:
assert "object" not in got_json
assert "message" not in got_json
# Collect e2e latencies for scheduling validation
e2e_latencies.append(
got_json["meta_info"]["e2e_latency"] if got_status == 200 else None
)
def _verify_max_running_requests_and_max_queued_request_validation(
max_running_requests: int, max_queued_requests: int
):
"""Verify running request and queued request numbers based on server logs."""
rr_pattern = re.compile(r"#running-req:\s*(\d+)")
qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
with open(STDERR_FILENAME) as lines:
for line in lines:
rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
if rr_match:
assert int(rr_match.group(1)) <= max_running_requests
if qr_match:
assert int(qr_match.group(1)) <= max_queued_requests
if __name__ == "__main__":
unittest.main()
......@@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase):
send_concurrent_generate_requests(self.base_url, num_requests=10)
)
assert 200 in status_codes
assert 503 in status_codes
assert all(status_code in [200, 503] for status_code in status_codes)
expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
assert status_codes == expected_status_codes
def test_max_running_requests_and_max_queued_request_validation(self):
"""Verify running request and queued request numbers based on server logs."""
......
......@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
def test_init_with_cache_aware_policy(self):
policy = SchedulePolicy(
policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
def test_init_with_cache_agnostic_policy(self):
policy = SchedulePolicy(
policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
......@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
policy="invalid",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
def test_init_with_disabled_cache(self):
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
policy = SchedulePolicy(
policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
policy="lpm",
tree_cache=disabled_tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
......@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
]
policy = SchedulePolicy(
policy="fcfs", tree_cache=tree_cache, enable_hierarchical_cache=True
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if FCFS keeps the original order
......@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
self.assertEqual(waiting_queue[1].rid, 3)
self.assertEqual(waiting_queue[2].rid, 2)
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled FCFS ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled FCFS ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
if __name__ == "__main__":
unittest.main()