Unverified Commit 76515f30 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Frontend] Use MQLLMEngine for embeddings models too (#8584)

parent 855c8ae2
...@@ -2,6 +2,7 @@ from dataclasses import dataclass ...@@ -2,6 +2,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import List, Mapping, Optional, Union from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError): ...@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):
@dataclass @dataclass
class RPCGenerateRequest: class RPCProcessRequest:
inputs: PromptInputs inputs: PromptInputs
sampling_params: SamplingParams params: Union[SamplingParams, PoolingParams]
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
...@@ -55,7 +56,7 @@ class RPCStartupResponse: ...@@ -55,7 +56,7 @@ class RPCStartupResponse:
tracing_enabled: bool tracing_enabled: bool
RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest] RPCStartupRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
......
...@@ -11,6 +11,7 @@ import zmq.asyncio ...@@ -11,6 +11,7 @@ import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined] from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -19,8 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -19,8 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest, RPCError, RPCHealthRequest,
RPCHealthRequest, RPCStartupRequest, RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse) RPCStartupResponse)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
...@@ -111,20 +112,8 @@ class MQLLMEngineClient: ...@@ -111,20 +112,8 @@ class MQLLMEngineClient:
@staticmethod @staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs): def is_unsupported_config(engine_args: AsyncEngineArgs):
if engine_args.pipeline_parallel_size > 1: # Pipeline parallel not yet supported
return True return engine_args.pipeline_parallel_size > 1
is_embedding = ModelConfig(
model=engine_args.model,
revision=engine_args.revision,
tokenizer=engine_args.model,
tokenizer_mode="auto",
trust_remote_code=engine_args.trust_remote_code,
quantization=engine_args.quantization,
seed=0,
dtype="auto").embedding_mode
return is_embedding
@contextmanager @contextmanager
def get_data_socket(self) -> Iterator[Socket]: def get_data_socket(self) -> Iterator[Socket]:
...@@ -382,12 +371,9 @@ class MQLLMEngineClient: ...@@ -382,12 +371,9 @@ class MQLLMEngineClient:
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
if self._errored_with is not None:
return ENGINE_DEAD_ERROR(self._errored_with) return ENGINE_DEAD_ERROR(self._errored_with)
else:
return ENGINE_DEAD_ERROR()
async def generate( def generate(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
sampling_params: SamplingParams, sampling_params: SamplingParams,
...@@ -396,6 +382,67 @@ class MQLLMEngineClient: ...@@ -396,6 +382,67 @@ class MQLLMEngineClient:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return self._process_request(inputs, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(inputs, pooling_params, request_id,
lora_request, trace_headers)
async def _process_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out. # If already dead, error out.
...@@ -410,19 +457,19 @@ class MQLLMEngineClient: ...@@ -410,19 +457,19 @@ class MQLLMEngineClient:
try: try:
# 2) Detach logits processors so that they can be pickled # 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower) # separately (may require cloudpickle which is slower)
if sampling_params.logits_processors: if isinstance(params, SamplingParams) and params.logits_processors:
# Defensive shallow copy # Defensive shallow copy
sampling_params = copy.copy(sampling_params) params = copy.copy(params)
logits_processors = sampling_params.logits_processors logits_processors = params.logits_processors
sampling_params.logits_processors = None params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors) lp_bytes = cloudpickle.dumps(logits_processors)
else: else:
lp_bytes = None lp_bytes = None
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCGenerateRequest( RPCProcessRequest(
inputs=inputs, inputs=inputs,
sampling_params=sampling_params, params=params,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
...@@ -452,8 +499,3 @@ class MQLLMEngineClient: ...@@ -452,8 +499,3 @@ class MQLLMEngineClient:
await self.abort(request_id) await self.abort(request_id)
finally: finally:
self.output_queues.pop(request_id) self.output_queues.pop(request_id)
async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
...@@ -6,7 +6,7 @@ from typing import Iterator, List, Optional, Union ...@@ -6,7 +6,7 @@ from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
import zmq import zmq
from vllm import AsyncEngineArgs, LLMEngine from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -15,8 +15,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -15,8 +15,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest, RPCError, RPCHealthRequest,
RPCHealthRequest, RPCStartupRequest, RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse) RPCStartupResponse)
# yapf: enable # yapf: enable
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -39,8 +39,8 @@ class MQLLMEngine: ...@@ -39,8 +39,8 @@ class MQLLMEngine:
in concurrnet manner. It runs a background loop and uses zeromq to in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc. receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine.generate` is kicked off when a new The :class:`LLMEngine` generate or encode process is kicked off when a new
RPCGenerateRequest is received by the input_socket. RPCProcessRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests, The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal adds them to the LLMEngine if there are any, calls the internal
...@@ -213,12 +213,13 @@ class MQLLMEngine: ...@@ -213,12 +213,13 @@ class MQLLMEngine:
frames = self.input_socket.recv_multipart(copy=False) frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer) request = pickle.loads(frames[0].buffer)
if isinstance(request, RPCGenerateRequest): if isinstance(request, RPCProcessRequest):
if len(frames) > 1: if len(frames) > 1:
# Use cloudpickle for logits processors # Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer) lprocs = cloudpickle.loads(frames[1].buffer)
request.sampling_params.logits_processors = lprocs request.params.logits_processors = lprocs
self._handle_generate_request(request) self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest): elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request) self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest): elif isinstance(request, RPCHealthRequest):
...@@ -231,8 +232,8 @@ class MQLLMEngine: ...@@ -231,8 +232,8 @@ class MQLLMEngine:
self._send_unhealthy(e) self._send_unhealthy(e)
raise e raise e
def _handle_generate_request(self, request: RPCGenerateRequest): def _handle_process_request(self, request: RPCProcessRequest):
"""Handle RPCGenerateRequest by adding it to the LLMEngine.""" """Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id = request.request_id request_id = request.request_id
if self._errored_with is not None: if self._errored_with is not None:
...@@ -245,7 +246,7 @@ class MQLLMEngine: ...@@ -245,7 +246,7 @@ class MQLLMEngine:
self.engine.add_request( self.engine.add_request(
request_id=request_id, request_id=request_id,
inputs=request.inputs, inputs=request.inputs,
params=request.sampling_params, params=request.params,
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request) prompt_adapter_request=request.prompt_adapter_request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment