Unverified commit b8a45721, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Use helper function to generate dummy messages in OpenAI MM tests (#26875)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 302ef403
...@@ -53,22 +53,35 @@ def base64_encoded_audio() -> dict[str, str]: ...@@ -53,22 +53,35 @@ def base64_encoded_audio() -> dict[str, str]:
} }
@pytest.mark.asyncio def dummy_messages_from_audio_url(
@pytest.mark.parametrize("model_name", [MODEL_NAME]) audio_urls: str | list[str],
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) content_text: str = "What's happening in this audio?",
async def test_single_chat_session_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
): ):
messages = [ if isinstance(audio_urls, str):
audio_urls = [audio_urls]
return [
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "audio_url", "audio_url": {"url": audio_url}}, *(
{"type": "text", "text": "What's happening in this audio?"}, {"type": "audio_url", "audio_url": {"url": audio_url}}
for audio_url in audio_urls
),
{"type": "text", "text": content_text},
], ],
} }
] ]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = dummy_messages_from_audio_url(audio_url)
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -138,20 +151,9 @@ async def test_single_chat_session_audio_base64encoded( ...@@ -138,20 +151,9 @@ async def test_single_chat_session_audio_base64encoded(
audio_url: str, audio_url: str,
base64_encoded_audio: dict[str, str], base64_encoded_audio: dict[str, str],
): ):
messages = [ messages = dummy_messages_from_audio_url(
{ f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
"role": "user", )
"content": [
{
"type": "audio_url",
"audio_url": {
"url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" # noqa: E501
},
},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -252,15 +254,7 @@ async def test_single_chat_session_input_audio( ...@@ -252,15 +254,7 @@ async def test_single_chat_session_input_audio(
async def test_chat_streaming_audio( async def test_chat_streaming_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str client: openai.AsyncOpenAI, model_name: str, audio_url: str
): ):
messages = [ messages = dummy_messages_from_audio_url(audio_url)
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": audio_url}},
{"type": "text", "text": "What's happening in this audio?"},
],
}
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -365,18 +359,7 @@ async def test_chat_streaming_input_audio( ...@@ -365,18 +359,7 @@ async def test_chat_streaming_input_audio(
async def test_multi_audio_input( async def test_multi_audio_input(
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str] client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
): ):
messages = [ messages = dummy_messages_from_audio_url(audio_urls)
{
"role": "user",
"content": [
*(
{"type": "audio_url", "audio_url": {"url": audio_url}}
for audio_url in audio_urls
),
{"type": "text", "text": "What's happening in this audio?"},
],
}
]
if len(audio_urls) > MAXIMUM_AUDIOS: if len(audio_urls) > MAXIMUM_AUDIOS:
with pytest.raises(openai.BadRequestError): # test multi-audio input with pytest.raises(openai.BadRequestError): # test multi-audio input
......
...@@ -55,22 +55,35 @@ def base64_encoded_video() -> dict[str, str]: ...@@ -55,22 +55,35 @@ def base64_encoded_video() -> dict[str, str]:
} }
@pytest.mark.asyncio def dummy_messages_from_video_url(
@pytest.mark.parametrize("model_name", [MODEL_NAME]) video_urls: str | list[str],
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) content_text: str = "What's in this video?",
async def test_single_chat_session_video(
client: openai.AsyncOpenAI, model_name: str, video_url: str
): ):
messages = [ if isinstance(video_urls, str):
video_urls = [video_urls]
return [
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "video_url", "video_url": {"url": video_url}}, *(
{"type": "text", "text": "What's in this video?"}, {"type": "video_url", "video_url": {"url": video_url}}
for video_url in video_urls
),
{"type": "text", "text": content_text},
], ],
} }
] ]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video(
client: openai.AsyncOpenAI, model_name: str, video_url: str
):
messages = dummy_messages_from_video_url(video_url)
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -137,15 +150,7 @@ async def test_error_on_invalid_video_url_type( ...@@ -137,15 +150,7 @@ async def test_error_on_invalid_video_url_type(
async def test_single_chat_session_video_beamsearch( async def test_single_chat_session_video_beamsearch(
client: openai.AsyncOpenAI, model_name: str, video_url: str client: openai.AsyncOpenAI, model_name: str, video_url: str
): ):
messages = [ messages = dummy_messages_from_video_url(video_url)
{
"role": "user",
"content": [
{"type": "video_url", "video_url": {"url": video_url}},
{"type": "text", "text": "What's in this video?"},
],
}
]
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -172,20 +177,9 @@ async def test_single_chat_session_video_base64encoded( ...@@ -172,20 +177,9 @@ async def test_single_chat_session_video_base64encoded(
video_url: str, video_url: str,
base64_encoded_video: dict[str, str], base64_encoded_video: dict[str, str],
): ):
messages = [ messages = dummy_messages_from_video_url(
{ f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
"role": "user", )
"content": [
{
"type": "video_url",
"video_url": {
"url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501
},
},
{"type": "text", "text": "What's in this video?"},
],
}
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -231,20 +225,10 @@ async def test_single_chat_session_video_base64encoded_beamsearch( ...@@ -231,20 +225,10 @@ async def test_single_chat_session_video_base64encoded_beamsearch(
video_url: str, video_url: str,
base64_encoded_video: dict[str, str], base64_encoded_video: dict[str, str],
): ):
messages = [ messages = dummy_messages_from_video_url(
{ f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
"role": "user", )
"content": [
{
"type": "video_url",
"video_url": {
"url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501
},
},
{"type": "text", "text": "What's in this video?"},
],
}
]
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
...@@ -265,15 +249,7 @@ async def test_single_chat_session_video_base64encoded_beamsearch( ...@@ -265,15 +249,7 @@ async def test_single_chat_session_video_base64encoded_beamsearch(
async def test_chat_streaming_video( async def test_chat_streaming_video(
client: openai.AsyncOpenAI, model_name: str, video_url: str client: openai.AsyncOpenAI, model_name: str, video_url: str
): ):
messages = [ messages = dummy_messages_from_video_url(video_url)
{
"role": "user",
"content": [
{"type": "video_url", "video_url": {"url": video_url}},
{"type": "text", "text": "What's in this video?"},
],
}
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -318,18 +294,7 @@ async def test_chat_streaming_video( ...@@ -318,18 +294,7 @@ async def test_chat_streaming_video(
async def test_multi_video_input( async def test_multi_video_input(
client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] client: openai.AsyncOpenAI, model_name: str, video_urls: list[str]
): ):
messages = [ messages = dummy_messages_from_video_url(video_urls)
{
"role": "user",
"content": [
*(
{"type": "video_url", "video_url": {"url": video_url}}
for video_url in video_urls
),
{"type": "text", "text": "What's in this video?"},
],
}
]
if len(video_urls) > MAXIMUM_VIDEOS: if len(video_urls) > MAXIMUM_VIDEOS:
with pytest.raises(openai.BadRequestError): # test multi-video input with pytest.raises(openai.BadRequestError): # test multi-video input
......
...@@ -78,6 +78,27 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]: ...@@ -78,6 +78,27 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]:
} }
def dummy_messages_from_image_url(
image_urls: str | list[str],
content_text: str = "What's in this image?",
):
if isinstance(image_urls, str):
image_urls = [image_urls]
return [
{
"role": "user",
"content": [
*(
{"type": "image_url", "image_url": {"url": image_url}}
for image_url in image_urls
),
{"type": "text", "text": content_text},
],
}
]
def get_hf_prompt_tokens(model_name, content, image_url): def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True, num_crops=4 model_name, trust_remote_code=True, num_crops=4
...@@ -107,15 +128,7 @@ async def test_single_chat_session_image( ...@@ -107,15 +128,7 @@ async def test_single_chat_session_image(
client: openai.AsyncOpenAI, model_name: str, image_url: str client: openai.AsyncOpenAI, model_name: str, image_url: str
): ):
content_text = "What's in this image?" content_text = "What's in this image?"
messages = [ messages = dummy_messages_from_image_url(image_url, content_text)
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": content_text},
],
}
]
max_completion_tokens = 10 max_completion_tokens = 10
# test single completion # test single completion
...@@ -188,15 +201,8 @@ async def test_error_on_invalid_image_url_type( ...@@ -188,15 +201,8 @@ async def test_error_on_invalid_image_url_type(
async def test_single_chat_session_image_beamsearch( async def test_single_chat_session_image_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str client: openai.AsyncOpenAI, model_name: str, image_url: str
): ):
messages = [ content_text = "What's in this image?"
{ messages = dummy_messages_from_image_url(image_url, content_text)
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "What's in this image?"},
],
}
]
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -226,20 +232,10 @@ async def test_single_chat_session_image_base64encoded( ...@@ -226,20 +232,10 @@ async def test_single_chat_session_image_base64encoded(
base64_encoded_image: dict[str, str], base64_encoded_image: dict[str, str],
): ):
content_text = "What's in this image?" content_text = "What's in this image?"
messages = [ messages = dummy_messages_from_image_url(
{ f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",
"role": "user", content_text,
"content": [ )
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501
},
},
{"type": "text", "text": content_text},
],
}
]
max_completion_tokens = 10 max_completion_tokens = 10
# test single completion # test single completion
...@@ -293,20 +289,10 @@ async def test_single_chat_session_image_base64encoded_beamsearch( ...@@ -293,20 +289,10 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
raw_image_url = TEST_IMAGE_ASSETS[image_idx] raw_image_url = TEST_IMAGE_ASSETS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
messages = [ messages = dummy_messages_from_image_url(
{ f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
"role": "user", )
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501
},
},
{"type": "text", "text": "What's in this image?"},
],
}
]
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
...@@ -326,15 +312,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( ...@@ -326,15 +312,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
async def test_chat_streaming_image( async def test_chat_streaming_image(
client: openai.AsyncOpenAI, model_name: str, image_url: str client: openai.AsyncOpenAI, model_name: str, image_url: str
): ):
messages = [ messages = dummy_messages_from_image_url(image_url)
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "What's in this image?"},
],
}
]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -381,18 +359,7 @@ async def test_chat_streaming_image( ...@@ -381,18 +359,7 @@ async def test_chat_streaming_image(
async def test_multi_image_input( async def test_multi_image_input(
client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] client: openai.AsyncOpenAI, model_name: str, image_urls: list[str]
): ):
messages = [ messages = dummy_messages_from_image_url(image_urls)
{
"role": "user",
"content": [
*(
{"type": "image_url", "image_url": {"url": image_url}}
for image_url in image_urls
),
{"type": "text", "text": "What's in this image?"},
],
}
]
if len(image_urls) > MAXIMUM_IMAGES: if len(image_urls) > MAXIMUM_IMAGES:
with pytest.raises(openai.BadRequestError): # test multi-image input with pytest.raises(openai.BadRequestError): # test multi-image input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment