Unverified commit 676f82ae, authored by Varun Chawla, committed by GitHub
Browse files

Add validation to reject non-text content in system messages (#34072)


Signed-off-by: Varun Chawla <varun_6april@hotmail.com>
parent 81bfc21a
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
...@@ -233,3 +233,140 @@ async def test_chat_error_stream(): ...@@ -233,3 +233,140 @@ async def test_chat_error_stream():
f"Expected error message in chunks: {chunks}" f"Expected error message in chunks: {chunks}"
) )
assert chunks[-1] == "data: [DONE]\n\n" assert chunks[-1] == "data: [DONE]\n\n"
@pytest.mark.parametrize(
    "image_content",
    [
        # Explicit 'type' field.
        [
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/image.jpg"},
            }
        ],
        # 'type' omitted: the validator has to infer it from the payload key.
        [{"image_url": {"url": "https://example.com/image.jpg"}}],
    ],
)
def test_system_message_warns_on_image(image_content):
    """An image part inside a system message must produce a warning."""
    system_message = {"role": "system", "content": image_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as mock_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

        warn_once = mock_logger.warning_once
        warn_once.assert_called()
        # Stringify the recorded call so both the format string and its
        # arguments can be searched in one go.
        recorded_call = str(warn_once.call_args)
        assert "System messages should only contain text" in recorded_call
        assert "image_url" in recorded_call
def test_system_message_accepts_text():
    """A plain-string system message is accepted unchanged."""
    system_prompt = {"role": "system", "content": "You are a helpful assistant."}

    # Building the request is itself the check: it must not raise.
    request = ChatCompletionRequest(model=MODEL_NAME, messages=[system_prompt])

    first_message = request.messages[0]
    assert first_message["role"] == "system"
def test_system_message_accepts_text_array():
    """A system message whose content is a list of text parts is accepted."""
    text_part = {"type": "text", "text": "You are a helpful assistant."}
    system_prompt = {"role": "system", "content": [text_part]}

    # Building the request is itself the check: it must not raise.
    request = ChatCompletionRequest(model=MODEL_NAME, messages=[system_prompt])

    first_message = request.messages[0]
    assert first_message["role"] == "system"
def test_user_message_accepts_image():
    """User messages may still mix text and image parts."""
    question = {"type": "text", "text": "What's in this image?"}
    picture = {
        "type": "image_url",
        "image_url": {"url": "https://example.com/image.jpg"},
    }
    user_message = {"role": "user", "content": [question, picture]}

    # Building the request is itself the check: it must not raise.
    request = ChatCompletionRequest(model=MODEL_NAME, messages=[user_message])

    first_message = request.messages[0]
    assert first_message["role"] == "user"
@pytest.mark.parametrize(
    "audio_content",
    [
        # Explicit 'type' field.
        [
            {
                "type": "input_audio",
                "input_audio": {"data": "base64data", "format": "wav"},
            }
        ],
        # 'type' omitted: the validator has to infer it from the payload key.
        [{"input_audio": {"data": "base64data", "format": "wav"}}],
    ],
)
def test_system_message_warns_on_audio(audio_content):
    """An audio part inside a system message must produce a warning."""
    system_message = {"role": "system", "content": audio_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as mock_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

        warn_once = mock_logger.warning_once
        warn_once.assert_called()
        # Stringify the recorded call so both the format string and its
        # arguments can be searched in one go.
        recorded_call = str(warn_once.call_args)
        assert "System messages should only contain text" in recorded_call
        assert "input_audio" in recorded_call
@pytest.mark.parametrize(
    "video_content",
    [
        # Explicit 'type' field.
        [
            {
                "type": "video_url",
                "video_url": {"url": "https://example.com/video.mp4"},
            }
        ],
        # 'type' omitted: the validator has to infer it from the payload key.
        [{"video_url": {"url": "https://example.com/video.mp4"}}],
    ],
)
def test_system_message_warns_on_video(video_content):
    """A video part inside a system message must produce a warning."""
    system_message = {"role": "system", "content": video_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as mock_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

        warn_once = mock_logger.warning_once
        warn_once.assert_called()
        # Stringify the recorded call so both the format string and its
        # arguments can be searched in one go.
        recorded_call = str(warn_once.call_args)
        assert "System messages should only contain text" in recorded_call
        assert "video_url" in recorded_call
...@@ -674,3 +674,52 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -674,3 +674,52 @@ class ChatCompletionRequest(OpenAIBaseModel):
"Parameter 'cache_salt' must be a non-empty string if provided." "Parameter 'cache_salt' must be a non-empty string if provided."
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_system_message_content_type(cls, data):
    """Warn if system messages contain non-text content.

    According to OpenAI API spec, system messages can only be of type
    'text'. We log a warning instead of rejecting to avoid breaking
    users who intentionally send multimodal system messages.
    See: https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-system_message

    Args:
        data: The raw, not-yet-validated request payload. Anything that
            is not a dict is passed through untouched.

    Returns:
        ``data`` unchanged; this validator only logs and never mutates
        or rejects the payload.
    """
    if not isinstance(data, dict):
        return data

    messages = data.get("messages")
    # Only walk a concrete sequence. An explicit ``"messages": None`` or a
    # scalar would make the loop below raise TypeError from this "before"
    # hook; skipping here leaves the malformed value in place so it can be
    # rejected by the regular field validation instead (presumably with a
    # proper validation error -- confirm against the request model).
    if not isinstance(messages, (list, tuple)):
        return data

    # Payload keys that imply a content type when the 'type' field is not
    # explicit, listed in priority order (first match wins).
    inferred_types = (
        ("image_url", "image_url"),
        ("image_pil", "image_url"),
        ("image_embeds", "image_embeds"),
        ("audio_url", "audio_url"),
        ("input_audio", "input_audio"),
        ("audio_embeds", "audio_embeds"),
        ("video_url", "video_url"),
    )

    for msg in messages:
        # Only system messages are restricted to text.
        if not isinstance(msg, dict) or msg.get("role") != "system":
            continue

        content = msg.get("content")
        # Plain-string content is text by construction; only the list
        # (multimodal) format can carry other part types.
        if not isinstance(content, list):
            continue

        for part in content:
            if not isinstance(part, dict):
                continue

            part_type = part.get("type")
            if part_type is None:
                part_type = next(
                    (name for key, name in inferred_types if key in part),
                    None,
                )

            # Warn about non-text content in system messages. Parts whose
            # type is missing and cannot be inferred are left alone.
            if part_type and part_type != "text":
                logger.warning_once(
                    "System messages should only contain text "
                    "content according to the OpenAI API spec. "
                    "Found content type: '%s'.",
                    part_type,
                )
    return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment