Unverified commit e22bb037, authored by jh-nv and committed by GitHub
Browse files

fix: Properly handle multiple text components from request (#5196)

parent 6fe88553
...@@ -24,6 +24,7 @@ from ..multimodal_utils import ( ...@@ -24,6 +24,7 @@ from ..multimodal_utils import (
MultiModalRequest, MultiModalRequest,
MyRequestOutput, MyRequestOutput,
ProcessMixIn, ProcessMixIn,
extract_user_text,
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
...@@ -157,10 +158,7 @@ class ProcessorHandler(ProcessMixIn): ...@@ -157,10 +158,7 @@ class ProcessorHandler(ProcessMixIn):
raise ValueError("prompt_template must contain '<prompt>' placeholder") raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text # Safely extract user text
try: user_text = extract_user_text(raw_request.messages)
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text) prompt = template.replace("<prompt>", user_text)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.chat_processor import ( from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor, ChatProcessor,
CompletionsProcessor, CompletionsProcessor,
...@@ -34,6 +35,7 @@ __all__ = [ ...@@ -34,6 +35,7 @@ __all__ = [
"CompletionsProcessor", "CompletionsProcessor",
"ProcessMixIn", "ProcessMixIn",
"encode_image_embeddings", "encode_image_embeddings",
"extract_user_text",
"get_encoder_components", "get_encoder_components",
"get_http_client", "get_http_client",
"ImageLoader", "ImageLoader",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility functions for processing chat messages."""
from typing import List
from dynamo.vllm.multimodal_utils.protocol import ChatMessage
def extract_user_text(messages: "List[ChatMessage]") -> str:
    """Extract and concatenate the text content of all user messages.

    Text items inside a single user message are concatenated without a
    separator; the resulting per-message strings are then joined with a
    newline. Messages from other roles, non-text content items and empty
    text items are ignored.

    Args:
        messages: Chat messages of the request. Each message exposes ``role``
            and ``content``. ``content`` is normally a list of items carrying
            ``type`` and (for text items) ``text``. A plain string is also
            accepted and treated as one text part, and ``None`` is treated as
            having no content.

    Returns:
        The combined user text.

    Raises:
        ValueError: If no user message contains any non-empty text.
    """
    # For a multi-turn conversation, each user turn is separated by a newline.
    # This is not a perfect solution as we encode a multi-turn conversation as
    # a single turn. However, multi-turn conversation in a single request is
    # not well defined in the spec.
    # TODO: Revisit this later when adding multi-turn conversation support.
    user_texts = []
    for message in messages:
        if message.role != "user":
            continue
        content = message.content
        if isinstance(content, str):
            # Iterating a plain string would yield characters, which have no
            # ``type`` attribute, so take the string as the turn's text.
            turn_text = content
        else:
            turn_text = "".join(
                item.text
                for item in content or ()
                if item.type == "text" and item.text
            )
        # Turns without any text (e.g. image-only) contribute nothing.
        if turn_text:
            user_texts.append(turn_text)
    if not user_texts:
        raise ValueError("No text content found in user messages")
    return "\n".join(user_texts)
...@@ -59,12 +59,6 @@ class ProcessMixIn(ProcessMixInRequired): ...@@ -59,12 +59,6 @@ class ProcessMixIn(ProcessMixInRequired):
Mixin for pre and post processing for vLLM Mixin for pre and post processing for vLLM
""" """
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self): def __init__(self):
pass pass
......
...@@ -18,7 +18,7 @@ import json ...@@ -18,7 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema from pydantic_core import core_schema
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
...@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel): ...@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel):
return SamplingParams(**v) return SamplingParams(**v)
return v return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
    """Turn ``SamplingParams`` into a dict by round-tripping through msgspec JSON."""
    # msgspec produces JSON bytes; json.loads converts them back to Python data.
    encoded = msgspec.json.encode(value)
    return json.loads(encoded)
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))},
) )
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for chat message utility functions."""
import pytest
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.protocol import (
ChatMessage,
ImageContent,
ImageURLDetail,
TextContent,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def test_extract_user_text_single_message():
    """A lone user message with one text item yields that text unchanged."""
    only_text = TextContent(type="text", text="Hello, world!")
    conversation = [ChatMessage(role="user", content=[only_text])]

    assert extract_user_text(conversation) == "Hello, world!"
def test_extract_user_text_multiple_text_parts():
    """Text items around an image in one user message are concatenated in order."""
    leading = TextContent(type="text", text="First part ")
    picture = ImageContent(
        type="image_url",
        image_url=ImageURLDetail(url="http://example.com/image.jpg"),
    )
    trailing = TextContent(type="text", text="second part")
    conversation = [ChatMessage(role="user", content=[leading, picture, trailing])]

    assert extract_user_text(conversation) == "First part second part"
def test_extract_user_text_multi_turn():
    """User turns are joined with a newline; the assistant turn between them is skipped."""
    turns = [
        ("user", "First question"),
        ("assistant", "First answer"),
        ("user", "Second question"),
    ]
    conversation = [
        ChatMessage(role=role, content=[TextContent(type="text", text=text)])
        for role, text in turns
    ]

    assert extract_user_text(conversation) == "First question\nSecond question"
def test_extract_user_text_only_images():
    """An image-only user message has no text, so ValueError is expected."""
    picture = ImageContent(
        type="image_url",
        image_url=ImageURLDetail(url="http://example.com/image.jpg"),
    )
    conversation = [ChatMessage(role="user", content=[picture])]

    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text(conversation)
def test_extract_user_text_empty_messages():
    """An empty message list has no user text, so ValueError is expected."""
    no_messages: list[ChatMessage] = []

    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text(no_messages)
def test_extract_user_text_no_user_messages():
    """Text from a non-user role does not count, so ValueError is expected."""
    assistant_text = TextContent(type="text", text="Just an assistant message")
    conversation = [ChatMessage(role="assistant", content=[assistant_text])]

    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text(conversation)
def test_extract_user_text_mixed_roles():
    """Only user-role text is returned when system and assistant turns are present."""
    turns = [
        ("system", "System prompt"),
        ("user", "User message 1"),
        ("assistant", "Assistant response"),
        ("user", "User message 2"),
    ]
    conversation = [
        ChatMessage(role=role, content=[TextContent(type="text", text=text)])
        for role, text in turns
    ]

    assert extract_user_text(conversation) == "User message 1\nUser message 2"
def test_extract_user_text_empty_text_content():
    """Empty text items surrounding a real one are dropped from the result."""
    parts = [
        TextContent(type="text", text=piece) for piece in ("", "Valid text", "")
    ]
    conversation = [ChatMessage(role="user", content=parts)]

    assert extract_user_text(conversation) == "Valid text"
...@@ -27,6 +27,7 @@ from dynamo.runtime.logging import configure_dynamo_logging ...@@ -27,6 +27,7 @@ from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module # To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_message_utils import extract_user_text
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import ( from utils.protocol import (
MultiModalInput, MultiModalInput,
...@@ -203,15 +204,7 @@ class Processor(ProcessMixIn): ...@@ -203,15 +204,7 @@ class Processor(ProcessMixIn):
if "<prompt>" not in template: if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder") raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text - find the text content item user_text = extract_user_text(raw_request.messages)
user_text = None
for message in raw_request.messages:
for item in message.content:
if item.type == "text":
user_text = item.text
break
if user_text is None:
raise ValueError("No text content found in the request messages")
prompt = template.replace("<prompt>", user_text) prompt = template.replace("<prompt>", user_text)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility functions for processing chat messages."""
def extract_user_text(messages) -> str:
    """Extract and concatenate the text content of all user messages.

    Text items inside a single user message are concatenated without a
    separator; the resulting per-message strings are then joined with a
    newline. Messages from other roles, non-text content items and empty
    text items are ignored.

    Args:
        messages: Iterable of message objects exposing ``role`` and
            ``content``. ``content`` is normally a list of items carrying
            ``type`` and (for text items) ``text``. A plain string is also
            accepted and treated as one text part, and ``None`` is treated as
            having no content.

    Returns:
        The combined user text.

    Raises:
        ValueError: If no user message contains any non-empty text.
    """
    user_texts = []
    for message in messages:
        if message.role != "user":
            continue
        content = message.content
        if isinstance(content, str):
            # Iterating a plain string would yield characters, which have no
            # ``type`` attribute, so take the string as the turn's text.
            turn_text = content
        else:
            turn_text = "".join(
                item.text
                for item in content or ()
                if item.type == "text" and item.text
            )
        # Turns without any text (e.g. image-only) contribute nothing.
        if turn_text:
            user_texts.append(turn_text)
    if not user_texts:
        raise ValueError("No text content found in user messages")
    # Separate the individual user turns with a newline.
    return "\n".join(user_texts)
...@@ -18,7 +18,7 @@ import json ...@@ -18,7 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema from pydantic_core import core_schema
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
...@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel): ...@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel):
return SamplingParams(**v) return SamplingParams(**v)
return v return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
    """Return ``value`` as a dict by encoding it with msgspec and decoding the JSON."""
    # msgspec emits JSON bytes; decode them so the field serializes as a dict.
    json_bytes = msgspec.json.encode(value)
    return json.loads(json_bytes)
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment