Unverified Commit e22bb037 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

fix: Properly handle multiple text components from request (#5196)

parent 6fe88553
......@@ -24,6 +24,7 @@ from ..multimodal_utils import (
MultiModalRequest,
MyRequestOutput,
ProcessMixIn,
extract_user_text,
vLLMMultimodalRequest,
)
......@@ -157,10 +158,7 @@ class ProcessorHandler(ProcessMixIn):
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
user_text = extract_user_text(raw_request.messages)
prompt = template.replace("<prompt>", user_text)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor,
CompletionsProcessor,
......@@ -34,6 +35,7 @@ __all__ = [
"CompletionsProcessor",
"ProcessMixIn",
"encode_image_embeddings",
"extract_user_text",
"get_encoder_components",
"get_http_client",
"ImageLoader",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility functions for processing chat messages."""
from typing import List
from dynamo.vllm.multimodal_utils.protocol import ChatMessage
def extract_user_text(messages: List[ChatMessage]) -> str:
"""Extract and concatenate text content from user messages."""
# This function finds all text content items from "user" role messages,
# and concatenates them. For multi-turn conversation, it adds a newline
# between each turn. This is not a perfect solution as we encode multi-turn
# conversation as a single turn. However, multi-turn conversation in a
# single request is not well defined in the spec.
# TODO: Revisit this later when adding multi-turn conversation support.
user_texts = []
for message in messages:
if message.role == "user":
# Collect all text content items from this user message
text_parts = []
for item in message.content:
if item.type == "text" and item.text:
text_parts.append(item.text)
# If this user message has text content, join it and add to user_texts
if text_parts:
user_texts.append("".join(text_parts))
if not user_texts:
raise ValueError("No text content found in user messages")
# Join all user turns with newline separator
return "\n".join(user_texts)
......@@ -59,12 +59,6 @@ class ProcessMixIn(ProcessMixInRequired):
Mixin for pre and post processing for vLLM
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self):
pass
......
......@@ -18,7 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
......@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel):
return SamplingParams(**v)
return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
"""Serialize SamplingParams using msgspec and return as dict."""
return json.loads(msgspec.json.encode(value))
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))},
)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for chat message utility functions."""
import pytest
from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.protocol import (
ChatMessage,
ImageContent,
ImageURLDetail,
TextContent,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def test_extract_user_text_single_message():
"""Test extracting text from a single user message with one text content."""
messages = [
ChatMessage(
role="user", content=[TextContent(type="text", text="Hello, world!")]
)
]
result = extract_user_text(messages)
assert result == "Hello, world!"
def test_extract_user_text_multiple_text_parts():
"""Test extracting text from a user message with multiple text content items."""
messages = [
ChatMessage(
role="user",
content=[
TextContent(type="text", text="First part "),
ImageContent(
type="image_url",
image_url=ImageURLDetail(url="http://example.com/image.jpg"),
),
TextContent(type="text", text="second part"),
],
)
]
result = extract_user_text(messages)
assert result == "First part second part"
def test_extract_user_text_multi_turn():
"""Test extracting text from multi-turn conversation."""
messages = [
ChatMessage(
role="user", content=[TextContent(type="text", text="First question")]
),
ChatMessage(
role="assistant", content=[TextContent(type="text", text="First answer")]
),
ChatMessage(
role="user", content=[TextContent(type="text", text="Second question")]
),
]
result = extract_user_text(messages)
assert result == "First question\nSecond question"
def test_extract_user_text_only_images():
"""Test that ValueError is raised when messages contain only images."""
messages = [
ChatMessage(
role="user",
content=[
ImageContent(
type="image_url",
image_url=ImageURLDetail(url="http://example.com/image.jpg"),
)
],
)
]
with pytest.raises(ValueError, match="No text content found in user messages"):
extract_user_text(messages)
def test_extract_user_text_empty_messages():
"""Test that ValueError is raised when messages list is empty."""
messages: list[ChatMessage] = []
with pytest.raises(ValueError, match="No text content found in user messages"):
extract_user_text(messages)
def test_extract_user_text_no_user_messages():
"""Test that ValueError is raised when there are no user role messages."""
messages = [
ChatMessage(
role="assistant",
content=[TextContent(type="text", text="Just an assistant message")],
)
]
with pytest.raises(ValueError, match="No text content found in user messages"):
extract_user_text(messages)
def test_extract_user_text_mixed_roles():
"""Test extracting text only from user messages, ignoring other roles."""
messages = [
ChatMessage(
role="system", content=[TextContent(type="text", text="System prompt")]
),
ChatMessage(
role="user", content=[TextContent(type="text", text="User message 1")]
),
ChatMessage(
role="assistant",
content=[TextContent(type="text", text="Assistant response")],
),
ChatMessage(
role="user", content=[TextContent(type="text", text="User message 2")]
),
]
result = extract_user_text(messages)
assert result == "User message 1\nUser message 2"
def test_extract_user_text_empty_text_content():
"""Test that empty text content items are ignored."""
messages = [
ChatMessage(
role="user",
content=[
TextContent(type="text", text=""),
TextContent(type="text", text="Valid text"),
TextContent(type="text", text=""),
],
)
]
result = extract_user_text(messages)
assert result == "Valid text"
......@@ -27,6 +27,7 @@ from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_message_utils import extract_user_text
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import (
MultiModalInput,
......@@ -203,15 +204,7 @@ class Processor(ProcessMixIn):
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text - find the text content item
user_text = None
for message in raw_request.messages:
for item in message.content:
if item.type == "text":
user_text = item.text
break
if user_text is None:
raise ValueError("No text content found in the request messages")
user_text = extract_user_text(raw_request.messages)
prompt = template.replace("<prompt>", user_text)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utility functions for processing chat messages."""
def extract_user_text(messages) -> str:
"""Extract and concatenate text content from user messages."""
user_texts = []
for message in messages:
if message.role == "user":
# Collect all text content items from this user message
text_parts = []
for item in message.content:
if item.type == "text" and item.text:
text_parts.append(item.text)
# If this user message has text content, join it and add to user_texts
if text_parts:
user_texts.append("".join(text_parts))
if not user_texts:
raise ValueError("No text content found in user messages")
# Join all user turns with newline separator
return "\n".join(user_texts)
......@@ -18,7 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
......@@ -89,9 +89,13 @@ class vLLMGenerateRequest(BaseModel):
return SamplingParams(**v)
return v
@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
"""Serialize SamplingParams using msgspec and return as dict."""
return json.loads(msgspec.json.encode(value))
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment