Unverified Commit bd429f2b authored by Sebastian Schoennenbeck's avatar Sebastian Schoennenbeck Committed by GitHub
Browse files

[Core] Priority-based scheduling in async engine (#8850)

parent 18e60d7d
......@@ -420,6 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
......@@ -433,6 +434,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
......@@ -449,6 +451,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
......@@ -460,6 +463,9 @@ class _AsyncLLMEngine(LLMEngine):
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None:
arrival_time = time.time()
......@@ -479,6 +485,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
async def check_health_async(self) -> None:
......@@ -829,6 +836,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
......@@ -843,6 +851,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
......@@ -860,6 +869,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
......@@ -877,6 +887,11 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if (priority != 0
and not self.engine.scheduler_config.policy == "priority"):
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
......@@ -885,7 +900,9 @@ class AsyncLLMEngine:
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return stream.generator()
......@@ -896,7 +913,8 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
......@@ -913,6 +931,8 @@ class AsyncLLMEngine:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `RequestOutput` objects from the LLMEngine
......@@ -968,6 +988,7 @@ class AsyncLLMEngine:
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
......
......@@ -796,7 +796,7 @@ class LLMEngine:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority > 0 and not self.scheduler_config.policy == "priority":
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment