Unverified Commit 14fdd527 authored by harrisonlimh, committed by GitHub

feat: add priority based scheduling with priority based request acceptance and preemption (#8746)

parent f949ad57
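The commit threads an optional integer `priority` through the OpenAI-compatible and native request paths and adds priority-aware scheduling flags. A minimal client-side sketch of how the feature is exercised (hypothetical, not part of the diff; assumes a locally launched server on the default port, a model path of your choice, and the `requests` library; the payload shape mirrors the tests added in this commit):

# Hypothetical usage sketch. Assumes a server launched with the new flags, e.g.:
#   python -m sglang.launch_server --model-path <model> \
#       --enable-priority-scheduling --priority-scheduling-preemption-threshold 10
# and listening on the default local port.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "What is the capital of France?",
        "sampling_params": {"temperature": 0, "max_new_tokens": 50},
        # Higher integer values are scheduled first by default; pass
        # --schedule-low-priority-values-first to invert the ordering.
        "priority": 10,
    },
)
print(resp.status_code, resp.json())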
...@@ -228,6 +228,8 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
...@@ -543,6 +545,8 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
...@@ -644,6 +648,8 @@ class EmbeddingRequest(BaseModel):
# The request id.
rid: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
class EmbeddingObject(BaseModel):
......
...@@ -149,6 +149,7 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)
......
...@@ -107,6 +107,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
priority=request.priority,
customer_labels=customer_labels,
)
......
...@@ -125,6 +125,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request = EmbeddingReqInput(
**prompt_kwargs,
rid=request.rid,
priority=request.priority,
)
return adapted_request, request
......
...@@ -570,6 +570,7 @@ class TokenizedGenerateReqInput:
token_ids_logprob: List[int]
# Whether to stream output
stream: bool
# Whether to return hidden states
return_hidden_states: bool = False
...@@ -656,6 +657,8 @@ class EmbeddingReqInput:
modalities: Optional[List[str]] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
# Priority for the request
priority: Optional[int] = None
# For background responses (OpenAI responses API)
background: bool = False
...@@ -763,6 +766,8 @@ class TokenizedEmbeddingReqInput:
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request
priority: Optional[int] = None
@dataclass
......
...@@ -453,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
...@@ -504,6 +505,7 @@ class Req:
self.stream = stream
self.eos_token_ids = eos_token_ids
self.vocab_size = vocab_size
self.priority = priority
# For incremental decoding
# ----- | --------- read_ids -------|
...@@ -1517,12 +1519,35 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
self.release_req(idx, len(sorted_indices), server_args)
if len(retracted_reqs) == 0:
# Corner case: only one request left
raise ValueError(
"Failed to retract any request. No space left for only one request."
)
self.filter_batch(keep_indices=sorted_indices)
# Reqs in batch are filtered
total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
new_estimate_ratio = (
total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
) / total_max_new_tokens
new_estimate_ratio = min(1.0, new_estimate_ratio)
return retracted_reqs, new_estimate_ratio
def release_req(self, idx: int, remaining_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
seq_lens_cpu = self.seq_lens.cpu().numpy()
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
...@@ -1547,26 +1572,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
else:
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
num_tokens = remaining_req_count * global_config.retract_decode_steps
self._evict_tree_cache_if_needed(num_tokens)
req.reset_for_retract()
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
......
...@@ -28,6 +28,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
...@@ -82,10 +83,14 @@ class SchedulePolicy:
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool,
enable_priority_scheduling: bool,
schedule_low_priority_values_first: bool,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache
self.enable_priority_scheduling = enable_priority_scheduling
self.schedule_low_priority_values_first = schedule_low_priority_values_first
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
...@@ -97,7 +102,10 @@ class SchedulePolicy:
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
if self.enable_priority_scheduling:
SchedulePolicy._sort_by_priority_and_fcfs(
waiting_queue, self.schedule_low_priority_values_first
)
return False
policy = self._determine_active_policy(waiting_queue)
...@@ -120,12 +128,15 @@ class SchedulePolicy:
if policy == CacheAgnosticPolicy.FCFS:
pass
elif policy == CacheAgnosticPolicy.LOF:
SchedulePolicy._sort_by_longest_output(
waiting_queue,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
elif policy == CacheAgnosticPolicy.RANDOM:
SchedulePolicy._sort_randomly(waiting_queue)
else:
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
...@@ -231,8 +242,22 @@ class SchedulePolicy:
)
@staticmethod
def _sort_by_longest_output(
waiting_queue: List[Req],
enable_priority_scheduling: bool,
schedule_low_priority_values_first: bool,
) -> None:
"""Sorts the waiting queue based on the longest output (max_new_tokens). If using priority scheduling, sort by priority first."""
if enable_priority_scheduling:
if schedule_low_priority_values_first:
waiting_queue.sort(
key=lambda x: (x.priority, -x.sampling_params.max_new_tokens)
)
else:
waiting_queue.sort(
key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens)
)
else:
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
@staticmethod
...@@ -240,6 +265,16 @@ class SchedulePolicy:
"""Shuffles the waiting queue randomly."""
random.shuffle(waiting_queue)
@staticmethod
def _sort_by_priority_and_fcfs(
waiting_queue: List[Req], schedule_low_priority_values_first: bool
) -> None:
"""Sorts the waiting queue based on the request priority then received titmestamp."""
if schedule_low_priority_values_first:
waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
else:
waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
@staticmethod
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
for child in cur_node.children.values():
...@@ -279,6 +314,7 @@ class PrefillAdder:
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
self.page_size = page_size
self.tree_cache = tree_cache
...@@ -295,6 +331,7 @@ class PrefillAdder:
self.req_states = None
self.can_run_list = []
self.preempt_list = []
self.new_chunked_req = None
self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
...@@ -303,11 +340,7 @@ class PrefillAdder:
if running_batch is not None:
self.rem_total_token_offset += sum(
[
self._get_running_request_total_token_offset(r)
for r in running_batch.reqs
]
)
...@@ -316,6 +349,19 @@ class PrefillAdder:
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
)
def _get_running_request_total_token_offset(self, req: Req) -> int:
return (
min(
(req.sampling_params.max_new_tokens - len(req.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
)
@property
def rem_total_tokens(self):
if self.is_hybrid:
...@@ -568,3 +614,61 @@ class PrefillAdder:
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()
def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
"""
Preempt running requests to serve the new request if the priority threshold is met and enough tokens can be freed.
Returns True if preemption was committed, and the new request can be scheduled.
"""
# Iterate running requests to find preemptible requests
if server_args.schedule_low_priority_values_first:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (-x.priority, -x.queue_time_start),
)
else:
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (x.priority, -x.queue_time_start),
)
preemptible_reqs = []
min_tokens_to_remove = (
req.extend_input_len
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
- self.rem_total_tokens
)
for running_req in sorted_running_reqs:
if running_req in self.preempt_list:
continue
# Priority difference needs to meet the threshold to be preemptible.
priority_diff = req.priority - running_req.priority
if server_args.schedule_low_priority_values_first:
priority_diff *= -1
if priority_diff > self.priority_scheduling_preemption_threshold:
preemptible_reqs.append(running_req)
min_tokens_to_remove -= self._get_running_request_total_token_offset(
running_req
)
# Check max token count limit can be met
if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
return False
# Preempt running requests. Release allocated resources for immediate usage.
preemptible_reqs = set(preemptible_reqs)
keep_indices = []
release_counter = 0
for i, running_req in enumerate(self.running_batch.reqs):
if running_req in preemptible_reqs:
self.rem_total_token_offset -= (
self._get_running_request_total_token_offset(running_req)
)
release_counter += 1
self.running_batch.release_req(
i, len(self.running_batch.reqs) - release_counter, server_args
)
else:
keep_indices.append(i)
self.running_batch.filter_batch(keep_indices=keep_indices)
self.preempt_list.extend(preemptible_reqs)
return True
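For reference, preempt_to_schedule only treats a running request as preemptible when the priority gap strictly exceeds the configured threshold. A small standalone sketch of that check (hypothetical helper, not part of the diff; the values mirror the default threshold of 10 exercised by the tests below):

def is_preemptible(
    incoming_priority: int,
    running_priority: int,
    threshold: int = 10,
    low_values_first: bool = False,
) -> bool:
    # Mirrors the check in PrefillAdder.preempt_to_schedule above:
    # the priority difference must be strictly greater than the threshold.
    priority_diff = incoming_priority - running_priority
    if low_values_first:
        priority_diff *= -1
    return priority_diff > threshold

# With the default threshold, priority 20 can preempt priority 0 (20 > 10),
# but priority 10 cannot (10 is not strictly greater than 10).
assert is_preemptible(20, 0) is True
assert is_preemptible(10, 0) is False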
...@@ -243,6 +243,13 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.enable_priority_scheduling = server_args.enable_priority_scheduling
self.schedule_low_priority_values_first = (
server_args.schedule_low_priority_values_first
)
self.priority_scheduling_preemption_threshold = (
server_args.priority_scheduling_preemption_threshold
)
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
...@@ -487,7 +494,12 @@ class Scheduler(
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
self.enable_priority_scheduling,
self.schedule_low_priority_values_first,
)
# Enable preemption for priority scheduling.
self.try_preemption = self.enable_priority_scheduling
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
...@@ -1150,20 +1162,6 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWrapper):
worker_id = recv_req.worker_id
...@@ -1233,6 +1231,7 @@ class Scheduler(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
priority=recv_req.priority,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
...@@ -1382,6 +1381,9 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self._set_or_validate_priority(req)
if self._abort_on_queued_limit(req):
return
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
trace_slice_end("process req", req.rid, auto_next_anon=True)
...@@ -1408,7 +1410,70 @@ class Scheduler(
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
else:
for req in reqs:
self._set_or_validate_priority(req)
if not self._abort_on_queued_limit(req):
self.waiting_queue.append(req)
def _set_or_validate_priority(self, req: Req):
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
if self.enable_priority_scheduling and req.priority is None:
if self.schedule_low_priority_values_first:
req.priority = sys.maxsize
else:
req.priority = -sys.maxsize - 1
elif not self.enable_priority_scheduling and req.priority is not None:
abort_req = AbortReq(
req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
if (
self.max_queued_requests is None
or len(self.waiting_queue) + 1 <= self.max_queued_requests
):
return False
# Reject the incoming request by default.
req_to_abort = recv_req
message = "The request queue is full."
if self.enable_priority_scheduling:
# With priority scheduling, consider aborting an existing request based on its priority.
# direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
# max(...) with key (direction * priority, queue_time_start) picks the least-preferred request.
# Ties: the request with the later queue_time_start (newer) is evicted first. Abort an existing request only if the incoming one is strictly preferred.
direction = 1 if self.schedule_low_priority_values_first else -1
key_fn = lambda item: (
direction * item[1].priority,
item[1].queue_time_start,
)
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
abort_existing_req = (
direction * recv_req.priority < direction * candidate_req.priority
)
if abort_existing_req:
self.waiting_queue.pop(idx)
req_to_abort = candidate_req
message = "The request is aborted by a higher priority request."
self.send_to_tokenizer.send_pyobj(
AbortReq(
req_to_abort.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": message,
},
)
)
return req_to_abort.rid == recv_req.rid
def handle_embedding_request(
self,
...@@ -1420,6 +1485,7 @@ class Scheduler(
recv_req.input_ids,
recv_req.sampling_params,
token_type_ids=recv_req.token_type_ids,
priority=recv_req.priority,
)
req.tokenizer = self.tokenizer
...@@ -1680,6 +1746,10 @@ class Scheduler(
if self.grammar_queue:
self.move_ready_grammar_requests()
if self.try_preemption:
# Reset batch_is_full to try preemption with a prefill adder.
self.running_batch.batch_is_full = False
# Handle the cases where prefill is not allowed
if (
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
...@@ -1692,7 +1762,11 @@ class Scheduler(
# as the space for the chunked request has just been released.
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
if (
self.get_num_allocatable_reqs(running_bs) <= 0
and not self.chunked_req
and not self.try_preemption
):
self.running_batch.batch_is_full = True
return None
...@@ -1712,6 +1786,7 @@ class Scheduler(
self.max_prefill_tokens,
self.chunked_prefill_size,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
if self.chunked_req is not None:
...@@ -1732,15 +1807,19 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
self.running_batch.batch_is_full = True
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# In prefill mode, prealloc queue and transfer queue can also take memory,
# so we need to check the actual available size.
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
self.running_batch.batch_is_full = True
if self.running_batch.batch_is_full:
if not self.try_preemption:
break
if not adder.preempt_to_schedule(req, self.server_args):
break
if self.enable_hicache_storage:
...@@ -1777,6 +1856,8 @@ class Scheduler(
self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
]
if adder.preempt_list:
self._extend_requests_to_queue(adder.preempt_list)
if adder.new_chunked_req is not None:
assert self.chunked_req is None
......
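The acceptance path above (_abort_on_queued_limit) evicts the least-preferred waiting request only when the incoming request is strictly preferred. A standalone illustration of how the victim is chosen (hypothetical stand-in class, not part of the diff):

from dataclasses import dataclass


@dataclass
class FakeReq:  # hypothetical stand-in for the scheduler's Req
    rid: str
    priority: int
    queue_time_start: float


waiting = [FakeReq("a", 3, 0.0), FakeReq("b", 1, 1.0), FakeReq("c", 1, 2.0)]
direction = -1  # default mode: larger priority values are preferred
idx, victim = max(
    enumerate(waiting),
    key=lambda item: (direction * item[1].priority, item[1].queue_time_start),
)
# "c" has the lowest priority and, among the tied requests, the newest
# queue_time_start, so it is the one aborted when the queue is full.
assert victim.rid == "c"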
...@@ -738,6 +738,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
...@@ -747,6 +748,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
mm_inputs,
token_type_ids,
sampling_params,
priority=obj.priority,
)
return tokenized_obj
......
...@@ -149,8 +149,8 @@ class TpModelWorker:
assert self.max_running_requests > 0, "max_running_request is zero"
self.max_queued_requests = server_args.max_queued_requests
assert (
self.max_queued_requests is None or self.max_queued_requests >= 1
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
self.max_req_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
......
...@@ -172,11 +172,14 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "fcfs"
enable_priority_scheduling: bool = False
schedule_low_priority_values_first: bool = False
priority_scheduling_preemption_threshold: int = 10
schedule_conservativeness: float = 1.0
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
...@@ -1166,6 +1169,24 @@ class ServerArgs:
choices=["lpm", "random", "fcfs", "dfs-weight", "lof", "priority"],
help="The scheduling policy of the requests.",
)
parser.add_argument(
"--enable-priority-scheduling",
action="store_true",
default=ServerArgs.enable_priority_scheduling,
help="Enable priority scheduling. Requests with higher priority integer values will be scheduled first by default.",
)
parser.add_argument(
"--schedule-low-priority-values-first",
action="store_true",
default=ServerArgs.schedule_low_priority_values_first,
help="If specified with --enable-priority-scheduling, the scheduler will schedule requests with lower priority integer values first.",
)
parser.add_argument(
"--priority-scheduling-preemption-threshold",
type=int,
default=ServerArgs.priority_scheduling_preemption_threshold,
help="Minimum difference in priorities for an incoming request to have to preempt running request(s).",
)
parser.add_argument(
"--schedule-conservativeness",
type=float,
...@@ -2455,6 +2476,13 @@ class ServerArgs:
"--generation-tokens-buckets", self.generation_tokens_buckets
)
# Check scheduling policy
if self.enable_priority_scheduling:
assert self.schedule_policy in [
"fcfs",
"lof",
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......
...@@ -17,7 +17,7 @@ from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
...@@ -1390,6 +1390,41 @@ async def send_concurrent_generate_requests(
return await asyncio.gather(*tasks)
async def send_concurrent_generate_requests_with_custom_params(
base_url: str,
custom_params: List[dict[str, Any]],
) -> List[Tuple[int, Any]]:
"""Sends generate requests concurrently with custom parameters and returns a list of (status code, response JSON) tuples."""
base_payload = {
"text": """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
""",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
},
}
async def async_generate_with_priority(req):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{base_url}/generate",
json=req,
) as response:
resp_json = await response.json()
return (response.status, resp_json)
tasks = []
for c in custom_params:
req = base_payload.copy()
req.update(c)
tasks.append(asyncio.create_task(async_generate_with_priority(req)))
return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(
......
...@@ -95,6 +95,7 @@ suites = {
TestFile("test_original_logprobs.py", 200),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_priority_scheduling.py", 100),
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105),
TestFile("test_regex_constrained.py", 64),
......
import asyncio
import os
import re
import unittest
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
CustomTestCase,
popen_launch_server,
send_concurrent_generate_requests_with_custom_params,
)
class TestPriorityScheduling(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--max-running-requests", # Enforce max request concurrency is 1
"1",
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
_verify_max_running_requests_and_max_queued_request_validation(1, 3)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def test_priority_scheduling_request_ordering_validation(self):
"""Verify pending requests are ordered by priority and received timestamp."""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first
{"priority": 1}, # third
{"priority": 1}, # fourth
{"priority": 2}, # second
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[3] < e2e_latencies[1] < e2e_latencies[2]
def test_priority_scheduling_existing_requests_abortion_validation(self):
"""Verify lower priority requests are aborted when incoming requests have higher priority"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 1,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first and holds the running queue capacity
{"priority": 2}, # aborted by request 5
{"priority": 3}, # aborted by request 6
{"priority": 4}, # aborted by request 7
{"priority": 5}, # fourth
{"priority": 6}, # third
{"priority": 7}, # second
],
)
)
expected_status_and_error_messages = [
(200, None),
(503, "The request is aborted by a higher priority request."),
(503, "The request is aborted by a higher priority request."),
(503, "The request is aborted by a higher priority request."),
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[6] < e2e_latencies[5] < e2e_latencies[4]
def test_priority_scheduling_incoming_request_rejection_validation(self):
"""Verify incoming requests are rejected when existing requests have higher priority"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 7,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first and holds the running queue capacity
{"priority": 6}, # second
{"priority": 5}, # third
{"priority": 4}, # fourth
{"priority": 3}, # rejected
{"priority": 2}, # rejected
{"priority": 1}, # rejected
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
(200, None),
(503, "The request queue is full."),
(503, "The request queue is full."),
(503, "The request queue is full."),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[1] < e2e_latencies[2] < e2e_latencies[3]
def test_priority_scheduling_preemption_meeting_threshold_validation(self):
"""Verify running requests are preempted by requests with priorities meeting the preemption threshold"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
}, # starts being processed first then preempted or pushed by later requests, and finishes last.
{
"priority": 10,
"sampling_params": {"max_new_tokens": 10000},
}, # scheduled after the third request, and finishes second.
{
"priority": 20,
"sampling_params": {"max_new_tokens": 10000},
}, # finishes first.
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[2] < e2e_latencies[1] < e2e_latencies[0]
def test_priority_scheduling_preemption_below_threshold_validation(self):
"""Verify running requests are not preempted by requests with priorities below preemption threshold"""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 0,
"sampling_params": {"max_new_tokens": 10000},
},
{
"priority": 5,
"sampling_params": {"max_new_tokens": 10000},
},
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
]
e2e_latencies = []
_verify_generate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)
assert e2e_latencies[0] < e2e_latencies[1]
class TestPrioritySchedulingMultipleRunningRequests(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--max-running-requests", # Enforce max request concurrency is 2
"2",
"--max-queued-requests", # Enforce max queued request number is 3
"3",
"--enable-priority-scheduling", # Enable priority scheduling
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
_verify_max_running_requests_and_max_queued_request_validation(2, 3)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def test_priority_scheduling_with_multiple_running_requests_preemption(self):
"""Verify preempting a subset of running requests is safe."""
responses = asyncio.run(
send_concurrent_generate_requests_with_custom_params(
self.base_url,
[
{
"priority": 10,
"sampling_params": {"max_new_tokens": 10000},
}, # finishes first
{
"priority": 5,
"sampling_params": {"max_new_tokens": 10000},
}, # preempted by the third request, then finishes third
{
"priority": 15,
"sampling_params": {"max_new_tokens": 10000},
}, # preempt the first request
],
)
)
expected_status_and_error_messages = [
(200, None),
(200, None),
(200, None),
]
_verify_generate_responses(responses, expected_status_and_error_messages, [])
def _verify_generate_responses(
responses: List[Tuple[int, Any]],
expected_code_and_error_message: List[Tuple[int, Optional[str]]],
e2e_latencies: List[Optional[float]],
):
"""
Verify generate response results are as expected based on status code and response json object content.
In addition, collects e2e latency info to verify scheduling and processing ordering.
"""
for got, expected in zip(responses, expected_code_and_error_message):
got_status, got_json = got
expected_status, expected_err_msg = expected
# Check status code is as expected
assert got_status == expected_status
# Check error message content or fields' existence based on status code
if got_status != 200:
assert got_json["object"] == "error"
assert got_json["message"] == expected_err_msg
else:
assert "object" not in got_json
assert "message" not in got_json
# Collect e2e latencies for scheduling validation
e2e_latencies.append(
got_json["meta_info"]["e2e_latency"] if got_status == 200 else None
)
def _verify_max_running_requests_and_max_queued_request_validation(
max_running_requests: int, max_queued_requests: int
):
"""Verify running request and queued request numbers based on server logs."""
rr_pattern = re.compile(r"#running-req:\s*(\d+)")
qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
with open(STDERR_FILENAME) as lines:
for line in lines:
rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
if rr_match:
assert int(rr_match.group(1)) <= max_running_requests
if qr_match:
assert int(qr_match.group(1)) <= max_queued_requests
if __name__ == "__main__":
unittest.main()
...@@ -65,9 +65,8 @@ class TestMaxQueuedRequests(CustomTestCase):
send_concurrent_generate_requests(self.base_url, num_requests=10)
)
expected_status_codes = [200, 200, 503, 503, 503, 503, 503, 503, 503, 503]
assert status_codes == expected_status_codes
def test_max_running_requests_and_max_queued_request_validation(self):
"""Verify running request and queued request numbers based on server logs."""
......
...@@ -18,13 +18,21 @@ class TestSchedulePolicy(CustomTestCase):
def test_init_with_cache_aware_policy(self):
policy = SchedulePolicy(
policy="lpm",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAwarePolicy.LPM)
def test_init_with_cache_agnostic_policy(self):
policy = SchedulePolicy(
policy="fcfs",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
...@@ -34,12 +42,18 @@ class TestSchedulePolicy(CustomTestCase):
policy="invalid",
tree_cache=self.tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
def test_init_with_disabled_cache(self):
disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
policy = SchedulePolicy(
policy="lpm",
tree_cache=disabled_tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)
...@@ -52,7 +66,11 @@ class TestSchedulePolicy(CustomTestCase):
]
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if FCFS keeps the original order
...@@ -60,6 +78,126 @@ class TestSchedulePolicy(CustomTestCase):
self.assertEqual(waiting_queue[1].rid, 3)
self.assertEqual(waiting_queue[2].rid, 2)
def test_calc_priority_priority_enabled_fcfs_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams()),
Req(3, "a b c", [1, 2, 3], SamplingParams()),
Req(2, "a", [1], SamplingParams()),
]
waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0
waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1
waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0
policy = SchedulePolicy(
policy="fcfs",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority enabled fcfs ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10)),
Req(2, "a", [1], SamplingParams(max_new_tokens=100)),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=False,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling(self):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=0),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=0),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=False,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
def test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first(
self,
):
tree_cache = RadixCache(None, None, False)
waiting_queue = [
Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0),
Req(3, "a b c", [1, 2, 3], SamplingParams(max_new_tokens=10), priority=1),
Req(2, "a", [1], SamplingParams(max_new_tokens=100), priority=1),
]
policy = SchedulePolicy(
policy="lof",
tree_cache=tree_cache,
enable_hierarchical_cache=True,
enable_priority_scheduling=True,
schedule_low_priority_values_first=True,
)
policy.calc_priority(waiting_queue)
# Check if priority-enabled longest-output-first ordering is applied.
self.assertEqual(waiting_queue[0].rid, 1)
self.assertEqual(waiting_queue[1].rid, 2)
self.assertEqual(waiting_queue[2].rid, 3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()