Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35bd2151
Unverified
Commit
35bd2151
authored
Oct 01, 2024
by
Sebastian Schoennenbeck
Committed by
GitHub
Oct 01, 2024
Browse files
[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (#8965)
parent
1fe0a426
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
53 additions
and
5 deletions
+53
-5
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+4
-0
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+5
-0
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+16
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+3
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+22
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-0
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-0
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+1
-0
No files found.
vllm/engine/async_llm_engine.py
View file @
35bd2151
...
...
@@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
...
...
@@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
...
...
@@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
pooling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
...
...
vllm/engine/multiprocessing/__init__.py
View file @
35bd2151
...
...
@@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
priority
:
int
=
0
@
overload
# DEPRECATED
def
__init__
(
...
...
@@ -41,6 +42,7 @@ class RPCProcessRequest:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
...
...
@@ -53,6 +55,7 @@ class RPCProcessRequest:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
...
...
@@ -68,6 +71,7 @@ class RPCProcessRequest:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
...
...
@@ -84,6 +88,7 @@ class RPCProcessRequest:
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
priority
=
priority
@
dataclass
...
...
vllm/engine/multiprocessing/client.py
View file @
35bd2151
...
...
@@ -380,6 +380,7 @@ class MQLLMEngineClient:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
...
...
@@ -392,6 +393,7 @@ class MQLLMEngineClient:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
...
...
@@ -407,6 +409,7 @@ class MQLLMEngineClient:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
...
@@ -425,6 +428,9 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if
inputs
is
not
None
:
prompt
=
inputs
...
...
@@ -433,7 +439,7 @@ class MQLLMEngineClient:
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
prompt_adapter_request
,
priority
)
@
overload
# DEPRECATED
def
encode
(
...
...
@@ -444,6 +450,7 @@ class MQLLMEngineClient:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
...
...
@@ -455,6 +462,7 @@ class MQLLMEngineClient:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
...
...
@@ -469,6 +477,7 @@ class MQLLMEngineClient:
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
...
@@ -496,7 +505,7 @@ class MQLLMEngineClient:
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
lora_request
,
trace_headers
,
priority
)
async
def
_process_request
(
self
,
...
...
@@ -505,7 +514,8 @@ class MQLLMEngineClient:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
...
...
@@ -550,7 +560,9 @@ class MQLLMEngineClient:
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts
=
(
request_bytes
,
...
...
vllm/engine/protocol.py
View file @
35bd2151
...
...
@@ -40,7 +40,8 @@ class EngineClient(Protocol):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request."""
...
...
...
@@ -52,6 +53,7 @@ class EngineClient(Protocol):
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
...
...
...
vllm/entrypoints/openai/protocol.py
View file @
35bd2151
...
...
@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-chat-completion-extra-params
...
...
@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-completion-extra-params
...
...
@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
35bd2151
...
...
@@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
35bd2151
...
...
@@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
35bd2151
...
...
@@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment