Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
76515f30
Unverified
Commit
76515f30
authored
Sep 19, 2024
by
Nick Hill
Committed by
GitHub
Sep 19, 2024
Browse files
[Frontend] Use MQLLMEngine for embeddings models too (#8584)
parent
855c8ae2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
90 additions
and
46 deletions
+90
-46
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+4
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+74
-32
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+12
-11
No files found.
vllm/engine/multiprocessing/__init__.py
View file @
76515f30
...
...
@@ -2,6 +2,7 @@ from dataclasses import dataclass
from
enum
import
Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
vllm
import
PoolingParams
from
vllm.inputs
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
...
...
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
class
RPC
Generate
Request
:
class
RPC
Process
Request
:
inputs
:
PromptInputs
s
ampling
_p
arams
:
Samp
lingParams
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
...
...
@@ -55,7 +56,7 @@ class RPCStartupResponse:
tracing_enabled
:
bool
RPC_REQUEST_T
=
Union
[
RPC
Generate
Request
,
RPCAbortRequest
,
RPCHealthRequest
,
RPC_REQUEST_T
=
Union
[
RPC
Process
Request
,
RPCAbortRequest
,
RPCHealthRequest
,
RPCStartupRequest
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCError
]
...
...
vllm/engine/multiprocessing/client.py
View file @
76515f30
...
...
@@ -11,6 +11,7 @@ import zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq.asyncio
import
Socket
from
vllm
import
PoolingParams
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
...
...
@@ -19,8 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPC
Generate
Request
,
RPC
Health
Request
,
RPCStartupRequest
,
RPCError
,
RPC
Health
Request
,
RPC
Process
Request
,
RPCStartupRequest
,
RPCStartupResponse
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
...
...
@@ -111,20 +112,8 @@ class MQLLMEngineClient:
@
staticmethod
def
is_unsupported_config
(
engine_args
:
AsyncEngineArgs
):
if
engine_args
.
pipeline_parallel_size
>
1
:
return
True
is_embedding
=
ModelConfig
(
model
=
engine_args
.
model
,
revision
=
engine_args
.
revision
,
tokenizer
=
engine_args
.
model
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
engine_args
.
trust_remote_code
,
quantization
=
engine_args
.
quantization
,
seed
=
0
,
dtype
=
"auto"
).
embedding_mode
return
is_embedding
# Pipeline parallel not yet supported
return
engine_args
.
pipeline_parallel_size
>
1
@
contextmanager
def
get_data_socket
(
self
)
->
Iterator
[
Socket
]:
...
...
@@ -382,12 +371,9 @@ class MQLLMEngineClient:
@
property
def
dead_error
(
self
)
->
BaseException
:
if
self
.
_errored_with
is
not
None
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
else
:
return
ENGINE_DEAD_ERROR
()
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
async
def
generate
(
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
...
...
@@ -396,6 +382,67 @@ class MQLLMEngineClient:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
async
def
_process_request
(
self
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]:
"""Send an RPCProcessRequest to the RPCServer and stream responses."""
# If already dead, error out.
...
...
@@ -410,19 +457,19 @@ class MQLLMEngineClient:
try
:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if
sampling_
params
.
logits_processors
:
if
isinstance
(
params
,
SamplingParams
)
and
params
.
logits_processors
:
# Defensive shallow copy
sampling_
params
=
copy
.
copy
(
sampling_
params
)
logits_processors
=
sampling_
params
.
logits_processors
sampling_
params
.
logits_processors
=
None
params
=
copy
.
copy
(
params
)
logits_processors
=
params
.
logits_processors
params
.
logits_processors
=
None
lp_bytes
=
cloudpickle
.
dumps
(
logits_processors
)
else
:
lp_bytes
=
None
request_bytes
=
pickle
.
dumps
(
RPC
Generate
Request
(
RPC
Process
Request
(
inputs
=
inputs
,
sampling_params
=
sampling_
params
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
...
...
@@ -452,8 +499,3 @@ class MQLLMEngineClient:
await
self
.
abort
(
request_id
)
finally
:
self
.
output_queues
.
pop
(
request_id
)
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
vllm/engine/multiprocessing/engine.py
View file @
76515f30
...
...
@@ -6,7 +6,7 @@ from typing import Iterator, List, Optional, Union
import
cloudpickle
import
zmq
from
vllm
import
AsyncEngineArgs
,
LLMEngine
from
vllm
import
AsyncEngineArgs
,
LLMEngine
,
SamplingParams
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf conflicts with isort for this block
...
...
@@ -15,8 +15,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPC
Generate
Request
,
RPC
Health
Request
,
RPCStartupRequest
,
RPCError
,
RPC
Health
Request
,
RPC
Process
Request
,
RPCStartupRequest
,
RPCStartupResponse
)
# yapf: enable
from
vllm.logger
import
init_logger
...
...
@@ -39,8 +39,8 @@ class MQLLMEngine:
in concurrent manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine
.
generate
`
is kicked off when a new
RPC
Generate
Request is received by the input_socket.
The :class:`LLMEngine
`
generate
or encode process
is kicked off when a new
RPC
Process
Request is received by the input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
...
...
@@ -213,12 +213,13 @@ class MQLLMEngine:
frames
=
self
.
input_socket
.
recv_multipart
(
copy
=
False
)
request
=
pickle
.
loads
(
frames
[
0
].
buffer
)
if
isinstance
(
request
,
RPC
Generate
Request
):
if
isinstance
(
request
,
RPC
Process
Request
):
if
len
(
frames
)
>
1
:
# Use cloudpickle for logits processors
assert
isinstance
(
request
.
params
,
SamplingParams
)
lprocs
=
cloudpickle
.
loads
(
frames
[
1
].
buffer
)
request
.
sampling_
params
.
logits_processors
=
lprocs
self
.
_handle_
generate
_request
(
request
)
request
.
params
.
logits_processors
=
lprocs
self
.
_handle_
process
_request
(
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
self
.
_handle_abort_request
(
request
)
elif
isinstance
(
request
,
RPCHealthRequest
):
...
...
@@ -231,8 +232,8 @@ class MQLLMEngine:
self
.
_send_unhealthy
(
e
)
raise
e
def
_handle_
generate
_request
(
self
,
request
:
RPC
Generate
Request
):
"""Handle RPC
Generate
Request by adding it to the LLMEngine."""
def
_handle_
process
_request
(
self
,
request
:
RPC
Process
Request
):
"""Handle RPC
Process
Request by adding it to the LLMEngine."""
request_id
=
request
.
request_id
if
self
.
_errored_with
is
not
None
:
...
...
@@ -245,7 +246,7 @@ class MQLLMEngine:
self
.
engine
.
add_request
(
request_id
=
request_id
,
inputs
=
request
.
inputs
,
params
=
request
.
sampling_
params
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
prompt_adapter_request
=
request
.
prompt_adapter_request
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment