Unverified Commit 35bd2151 authored by Sebastian Schoennenbeck's avatar Sebastian Schoennenbeck Committed by GitHub
Browse files

[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (#8965)

parent 1fe0a426
...@@ -1043,6 +1043,7 @@ class AsyncLLMEngine: ...@@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -1057,6 +1058,8 @@ class AsyncLLMEngine: ...@@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
...@@ -1109,6 +1112,7 @@ class AsyncLLMEngine: ...@@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
pooling_params, pooling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority,
): ):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput) yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
......
...@@ -30,6 +30,7 @@ class RPCProcessRequest: ...@@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
@overload # DEPRECATED @overload # DEPRECATED
def __init__( def __init__(
...@@ -41,6 +42,7 @@ class RPCProcessRequest: ...@@ -41,6 +42,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
... ...
...@@ -53,6 +55,7 @@ class RPCProcessRequest: ...@@ -53,6 +55,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
... ...
...@@ -68,6 +71,7 @@ class RPCProcessRequest: ...@@ -68,6 +71,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
...@@ -84,6 +88,7 @@ class RPCProcessRequest: ...@@ -84,6 +88,7 @@ class RPCProcessRequest:
self.lora_request = lora_request self.lora_request = lora_request
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass @dataclass
......
...@@ -380,6 +380,7 @@ class MQLLMEngineClient: ...@@ -380,6 +380,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
... ...
...@@ -392,6 +393,7 @@ class MQLLMEngineClient: ...@@ -392,6 +393,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
... ...
...@@ -407,6 +409,7 @@ class MQLLMEngineClient: ...@@ -407,6 +409,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -425,6 +428,9 @@ class MQLLMEngineClient: ...@@ -425,6 +428,9 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
""" """
if inputs is not None: if inputs is not None:
prompt = inputs prompt = inputs
...@@ -433,7 +439,7 @@ class MQLLMEngineClient: ...@@ -433,7 +439,7 @@ class MQLLMEngineClient:
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
prompt_adapter_request) prompt_adapter_request, priority)
@overload # DEPRECATED @overload # DEPRECATED
def encode( def encode(
...@@ -444,6 +450,7 @@ class MQLLMEngineClient: ...@@ -444,6 +450,7 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
... ...
...@@ -455,6 +462,7 @@ class MQLLMEngineClient: ...@@ -455,6 +462,7 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
... ...
...@@ -469,6 +477,7 @@ class MQLLMEngineClient: ...@@ -469,6 +477,7 @@ class MQLLMEngineClient:
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...@@ -496,7 +505,7 @@ class MQLLMEngineClient: ...@@ -496,7 +505,7 @@ class MQLLMEngineClient:
and request_id is not None) and request_id is not None)
return self._process_request(prompt, pooling_params, request_id, return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers) lora_request, trace_headers, priority)
async def _process_request( async def _process_request(
self, self,
...@@ -505,7 +514,8 @@ class MQLLMEngineClient: ...@@ -505,7 +514,8 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]: EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
...@@ -550,7 +560,9 @@ class MQLLMEngineClient: ...@@ -550,7 +560,9 @@ class MQLLMEngineClient:
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)) prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine. # 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes, parts = (request_bytes,
......
...@@ -40,7 +40,8 @@ class EngineClient(Protocol): ...@@ -40,7 +40,8 @@ class EngineClient(Protocol):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.""" """Generate outputs for a request."""
... ...
...@@ -52,6 +53,7 @@ class EngineClient(Protocol): ...@@ -52,6 +53,7 @@ class EngineClient(Protocol):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.""" """Generate outputs for a request from an embedding model."""
... ...
......
...@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params # doc: end-completion-extra-params
...@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel): ...@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: end-embedding-pooling-params # doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)
......
...@@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
) )
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
......
...@@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority,
) )
generators.append(generator) generators.append(generator)
......
...@@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority,
) )
generators.append(generator) generators.append(generator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment