Unverified commit 7c655279, authored by Woosuk Kwon and committed by GitHub
Browse files

[V1] Use pickle for serializing EngineCoreRequest & Add multimodal inputs to...


[V1] Use pickle for serializing EngineCoreRequest & Add multimodal inputs to EngineCoreRequest (#10245)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 47db6ec8
import enum import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import Any, Dict, List, Optional, Union
import msgspec import msgspec
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
...@@ -22,7 +23,8 @@ class DetokenizerRequest: ...@@ -22,7 +23,8 @@ class DetokenizerRequest:
include_stop_str_in_output: bool include_stop_str_in_output: bool
class EngineCoreRequest(msgspec.Struct, omit_defaults=True): @dataclass
class EngineCoreRequest:
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec # but this object is currently not playing well with msgspec
...@@ -33,6 +35,9 @@ class EngineCoreRequest(msgspec.Struct, omit_defaults=True): ...@@ -33,6 +35,9 @@ class EngineCoreRequest(msgspec.Struct, omit_defaults=True):
# always be tokenized? # always be tokenized?
prompt: Optional[str] prompt: Optional[str]
prompt_token_ids: List[int] prompt_token_ids: List[int]
mm_data: Optional[MultiModalDataDict]
mm_placeholders: Optional[MultiModalPlaceholderDict]
mm_processor_kwargs: Optional[Dict[str, Any]]
sampling_params: SamplingParams sampling_params: SamplingParams
eos_token_id: Optional[int] eos_token_id: Optional[int]
arrival_time: float arrival_time: float
......
...@@ -19,6 +19,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, ...@@ -19,6 +19,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType) EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -315,7 +316,7 @@ class EngineCoreProc(EngineCore): ...@@ -315,7 +316,7 @@ class EngineCoreProc(EngineCore):
"""Input socket IO thread.""" """Input socket IO thread."""
# Msgpack serialization decoding. # Msgpack serialization decoding.
decoder_add_req = msgpack.Decoder(EngineCoreRequest) decoder_add_req = PickleEncoder()
decoder_abort_req = msgpack.Decoder(list[str]) decoder_abort_req = msgpack.Decoder(list[str])
with self.make_socket(input_path, zmq.constants.PULL) as socket: with self.make_socket(input_path, zmq.constants.PULL) as socket:
......
...@@ -11,6 +11,7 @@ from vllm.utils import get_open_zmq_ipc_path ...@@ -11,6 +11,7 @@ from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType) EngineCoreRequest, EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.serial_utils import PickleEncoder
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -115,7 +116,7 @@ class MPClient(EngineCoreClient): ...@@ -115,7 +116,7 @@ class MPClient(EngineCoreClient):
**kwargs, **kwargs,
): ):
# Serialization setup. # Serialization setup.
self.encoder = msgspec.msgpack.Encoder() self.encoder = PickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
......
...@@ -91,7 +91,10 @@ class Processor: ...@@ -91,7 +91,10 @@ class Processor:
# Make Request for EngineCore. # Make Request for EngineCore.
engine_core_request = EngineCoreRequest( engine_core_request = EngineCoreRequest(
request_id, processed_inputs.get("prompt"), request_id, processed_inputs.get("prompt"),
processed_inputs.get("prompt_token_ids"), sampling_params, processed_inputs.get("prompt_token_ids"),
processed_inputs.get("multi_modal_data"),
processed_inputs.get("multi_modal_placeholders"),
processed_inputs.get("mm_processor_kwargs"), sampling_params,
eos_token_id, arrival_time, lora_request) eos_token_id, arrival_time, lora_request)
return detokenizer_request, engine_core_request return detokenizer_request, engine_core_request
......
import pickle
class PickleEncoder:
    """Stateless serializer whose ``encode``/``decode`` pair wraps ``pickle``.

    Both methods simply delegate to the ``pickle`` module, so any picklable
    Python object can be round-tripped through this class.
    """

    def encode(self, obj) -> bytes:
        """Return ``obj`` serialized to bytes with ``pickle.dumps``.

        Raises whatever ``pickle.dumps`` raises when ``obj`` (or something it
        references) cannot be pickled.
        """
        return pickle.dumps(obj)

    def decode(self, data):
        """Return the object reconstructed from ``data`` with ``pickle.loads``.

        ``data`` must be a bytes-like pickle payload, e.g. the value returned
        by :meth:`encode`.
        """
        # SECURITY: unpickling can execute arbitrary code embedded in the
        # payload. This is only safe when ``data`` comes from a trusted
        # producer.
        # NOTE(review): confirm that every transport feeding this method is
        # not reachable by untrusted peers.
        return pickle.loads(data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment