Unverified commit 454c28ab, authored by Indrajit Bhosale, committed by GitHub
Browse files

fix: sampling params parsing in vllm EPD flow (#5813)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent a4600927
......@@ -274,7 +274,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
)
# Remove the image features from the request as they are not required
request.multimodal_inputs = None
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
request.multimodal_inputs = []
logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.info(f"{multi_modal_data}")
......
......@@ -86,13 +86,32 @@ class vLLMGenerateRequest(BaseModel):
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
# Workaround for vLLM SamplingParams serialization/deserialization issue.
#
# Problem: When SamplingParams is serialized via msgspec.json.encode(),
# Python sets are converted to JSON arrays (lists). The serialized dict
# includes private fields like _all_stop_token_ids. Upon deserialization,
# passing this dict to SamplingParams(**dict) causes __post_init__ to fail
# because it expects _all_stop_token_ids to be a set (to call .update()),
# but it's now a list.
#
# Solution: Filter out private fields (starting with '_') which are
# internal state that should be computed by __post_init__, not passed
# from serialized data. Public fields like stop_token_ids are preserved.
v = {k: val for k, val in v.items() if not k.startswith("_")}
return SamplingParams(**v)
return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
    """Serialize SamplingParams to a plain dict, dropping private fields.

    The value is encoded with ``msgspec.json.encode`` and decoded back into
    a dict. Any key starting with ``_`` (for example ``_all_stop_token_ids``)
    is then removed so that internal state is never sent over the wire.

    Why: JSON encoding turns Python sets into lists, and feeding such a
    list-valued private field back into ``SamplingParams(**data)`` on the
    receiving side breaks its post-init logic, which expects a set.

    Args:
        value: The sampling parameters attached to the request.

    Returns:
        A JSON-compatible dict containing only the public fields.
    """
    serialized = json.loads(msgspec.json.encode(value))
    # Private fields are derived state; the receiver recomputes them.
    return {k: v for k, v in serialized.items() if not k.startswith("_")}
model_config = ConfigDict(
arbitrary_types_allowed=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment