Unverified commit 1df84ff4, authored by Mick and committed by GitHub

ci: simplify multi-modality tests by using mixins (#9006)

parent 66d6be08
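This refactor replaces per-model copies of the image/video/audio test methods with capability mixins: a shared TestOpenAIOmniServerBase owns the server scaffolding and request-kwargs hooks, and each model class opts into ImageOpenAITestMixin, VideoOpenAITestMixin, and/or AudioOpenAITestMixin. Below is a minimal sketch of the pattern with the server launch and real requests stubbed out; all names are illustrative placeholders, not the actual sglang helpers.

# Minimal sketch of the mixin layout introduced by this PR (illustrative names).
import unittest

class OmniServerBase(unittest.TestCase):
    model = ""      # concrete classes set this in setUpClass
    process = None  # would hold the launched server process

    def get_request_kwargs(self):
        # Hook: model classes add per-request extras (e.g. a LoRA path).
        return {}

class ImageTestMixin(OmniServerBase):
    def test_image_chat_completion(self):
        self.assertTrue(self.model)  # stand-in for a real image request

class AudioTestMixin(OmniServerBase):
    def test_audio_chat_completion(self):
        self.assertTrue(self.model)  # stand-in for a real audio request

class TestSomeOmniModel(ImageTestMixin, AudioTestMixin):
    # Opts into exactly the modalities the model supports.
    @classmethod
    def setUpClass(cls):
        cls.model = "org/some-omni-model"  # placeholder model id

if __name__ == "__main__":
    # Drop the abstract bases so unittest only collects the concrete classes.
    del OmniServerBase, ImageTestMixin, AudioTestMixin
    unittest.main()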
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
if videos:
kwargs["videos"] = videos
if audios:
if self.arch in {
"Gemma3nForConditionalGeneration",
"Qwen2AudioForConditionalGeneration",
if self._processor.__class__.__name__ in {
"Gemma3nProcessor",
"Qwen2AudioProcessor",
}:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios
......
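The hunk above switches the audio-kwarg dispatch from the model architecture to the processor class name, presumably because several architectures can share one HF processor, and it is the processor that determines whether it accepts audio= or audios=. A sketch of the full branch, with the fallback side assumed from the elided original:

# Sketch only; the `audios` fallback branch is assumed from the elided code.
SINGULAR_AUDIO_PROCESSORS = {"Gemma3nProcessor", "Qwen2AudioProcessor"}

def audio_kwargs(processor, audios):
    if not audios:
        return {}
    if processor.__class__.__name__ in SINGULAR_AUDIO_PROCESSORS:
        return {"audio": audios}   # these processors take the singular kwarg
    return {"audios": audios}      # most HF processors take `audios`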
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image, logger
from sglang.srt.utils import ImageData, load_image, logger
from sglang.utils import get_exception_traceback
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_data: Union[str, bytes, ImageData],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
processor=None,
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
image_processor = processor.image_processor
try:
image, image_size = load_image(image_data)
url = image_data.url if isinstance(image_data, ImageData) else image_data
image, image_size = load_image(url)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
image_hash = hash(url)
pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
image_hash = hash(url)
if image_aspect_ratio == "pad":
image = expand2square(
image,
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
self,
image_data: Union[bytes, str, ImageData],
aspect_ratio: str,
grid_pinpoints: str,
):
if self.cpu_executor is not None:
loop = asyncio.get_event_loop()
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, ImageData]],
input_text,
request_obj,
*args,
......
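Note on the llava hunks above: the new code extracts url from ImageData before calling load_image and hash, so the image hash used as the cache key is computed the same way whether the input arrives as a raw str/bytes or wrapped in an ImageData object.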
@@ -110,8 +110,8 @@ suites = {
TestFile("test_utils_update_weights.py", 48),
TestFile("test_vision_chunked_prefill.py", 175),
TestFile("test_vlm_input_format.py", 300),
TestFile("test_vision_openai_server_a.py", 989),
TestFile("test_vision_openai_server_b.py", 620),
TestFile("test_vision_openai_server_a.py", 403),
TestFile("test_vision_openai_server_b.py", 446),
],
"per-commit-2-gpu": [
TestFile("lora/test_lora_tp.py", 116),
......
@@ -8,16 +8,28 @@ import unittest
from test_vision_openai_server_common import *
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestQwen2VLServer(TestOpenAIVisionServer):
class TestLlava(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.base_url += "/v1"
class TestQwen2VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
@@ -37,11 +49,8 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
self._test_video_chat_completion()
class TestQwen2_5_VLServer(TestOpenAIVisionServer):
class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2.5-VL-7B-Instruct"
@@ -61,9 +70,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
self._test_video_chat_completion()
class TestVLMContextLengthIssue(CustomTestCase):
@classmethod
@@ -137,11 +143,8 @@ class TestVLMContextLengthIssue(CustomTestCase):
# )
# cls.base_url += "/v1"
# def test_video_chat_completion(self):
# pass
class TestMinicpmvServer(TestOpenAIVisionServer):
class TestMinicpmvServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "openbmb/MiniCPM-V-2_6"
@@ -162,7 +165,7 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestInternVL2_5Server(TestOpenAIVisionServer):
class TestInternVL2_5Server(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "OpenGVLab/InternVL2_5-2B"
@@ -181,7 +184,7 @@ class TestInternVL2_5Server(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestMinicpmoServer(TestOpenAIVisionServer):
class TestMinicpmoServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "openbmb/MiniCPM-o-2_6"
@@ -201,12 +204,8 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
self._test_audio_ambient_completion()
class TestMimoVLServer(TestOpenAIVisionServer):
class TestMimoVLServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "XiaomiMiMo/MiMo-VL-7B-RL"
@@ -228,6 +227,95 @@ class TestMimoVLServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestVILAServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--trust-remote-code",
"--context-length=65536",
f"--revision={cls.revision}",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod
def setUpClass(cls):
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
from huggingface_hub import constants, snapshot_download
snapshot_download(
"microsoft/Phi-4-multimodal-instruct",
allow_patterns=["**/adapter_config.json"],
)
cls.model = "microsoft/Phi-4-multimodal-instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.70",
"--disable-radix-cache",
"--max-loras-per-batch",
"2",
"--revision",
revision,
"--lora-paths",
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
def get_vision_request_kwargs(self):
return {
"extra_body": {
"lora_path": "vision",
"top_k": 1,
"top_p": 1.0,
}
}
def get_audio_request_kwargs(self):
return {
"extra_body": {
"lora_path": "speech",
"top_k": 1,
"top_p": 1.0,
}
}
# This ambient-audio test is too hard for a small LLM to pass, so skip it
def test_audio_ambient_completion(self):
pass
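The two get_*_request_kwargs overrides above route each modality to its LoRA adapter through extra_body. A sketch of the client call the audio tests end up making; the endpoint below is illustrative, since the tests derive base_url from DEFAULT_URL_FOR_TEST:

import openai

# Illustrative endpoint; the tests use DEFAULT_URL_FOR_TEST + "/v1".
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Transcribe the attached audio."}],
    temperature=0,
    # Selects the "speech" adapter loaded via --lora-paths speech=...
    extra_body={"lora_path": "speech", "top_k": 1, "top_p": 1.0},
)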
if __name__ == "__main__":
del TestOpenAIVisionServer
del (
TestOpenAIOmniServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
)
unittest.main()
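Deleting the mixin classes from the module namespace before unittest.main() (here and in the sibling test file below) keeps the unittest loader from collecting the abstract bases as test cases of their own; otherwise their test_* methods would run against a base class that never launches a server.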
@@ -4,12 +4,11 @@ from test_vision_openai_server_common import *
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestPixtralServer(TestOpenAIVisionServer):
class TestPixtralServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "mistral-community/pixtral-12b"
@@ -29,11 +28,8 @@ class TestPixtralServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestMistral3_1Server(TestOpenAIVisionServer):
class TestMistral3_1Server(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "unsloth/Mistral-Small-3.1-24B-Instruct-2503"
@@ -53,11 +49,8 @@ class TestMistral3_1Server(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestDeepseekVL2Server(TestOpenAIVisionServer):
class TestDeepseekVL2Server(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "deepseek-ai/deepseek-vl2-small"
@@ -77,11 +70,8 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestJanusProServer(TestOpenAIVisionServer):
class TestJanusProServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "deepseek-ai/Janus-Pro-7B"
@@ -104,10 +94,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
def test_video_images_chat_completion(self):
pass
def test_single_image_chat_completion(self):
# Skip this test because it is flaky
pass
## Skip for ci test
# class TestLlama4Server(TestOpenAIVisionServer):
@@ -135,11 +121,8 @@ class TestJanusProServer(TestOpenAIVisionServer):
# )
# cls.base_url += "/v1"
# def test_video_chat_completion(self):
# pass
class TestGemma3itServer(TestOpenAIVisionServer):
class TestGemma3itServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "google/gemma-3-4b-it"
@@ -160,11 +143,8 @@ class TestGemma3itServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestGemma3nServer(TestOpenAIVisionServer):
class TestGemma3nServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "google/gemma-3n-E4B-it"
@@ -184,16 +164,15 @@ class TestGemma3nServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
# self._test_audio_ambient_completion()
# This ambient-audio test is too hard for a small LLM to pass, so skip it
def test_audio_ambient_completion(self):
pass
def test_mixed_image_audio_chat_completion(self):
    self._test_mixed_image_audio_chat_completion()
class TestQwen2AudioServer(TestOpenAIVisionServer):
class TestQwen2AudioServer(AudioOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -211,36 +190,8 @@ class TestQwen2AudioServer(TestOpenAIVisionServer):
)
cls.base_url += "/v1"
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
self._test_audio_ambient_completion()
# Qwen2Audio does not support image
def test_single_image_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_multi_turn_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_multi_images_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_video_images_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_regex(self):
pass
# Qwen2Audio does not support image
def test_mixed_batch(self):
pass
class TestKimiVLServer(TestOpenAIVisionServer):
class TestKimiVLServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
@@ -266,91 +217,6 @@ class TestKimiVLServer(TestOpenAIVisionServer):
pass
class TestPhi4MMServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
from huggingface_hub import constants, snapshot_download
snapshot_download(
"microsoft/Phi-4-multimodal-instruct",
allow_patterns=["**/adapter_config.json"],
)
cls.model = "microsoft/Phi-4-multimodal-instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.70",
"--disable-radix-cache",
"--max-loras-per-batch",
"2",
"--revision",
revision,
"--lora-paths",
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
def get_vision_request_kwargs(self):
return {
"extra_body": {
"lora_path": "vision",
"top_k": 1,
"top_p": 1.0,
}
}
def get_audio_request_kwargs(self):
return {
"extra_body": {
"lora_path": "speech",
"top_k": 1,
"top_p": 1.0,
}
}
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
# self._test_audio_ambient_completion()
class TestVILAServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--trust-remote-code",
"--context-length=65536",
f"--revision={cls.revision}",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
# Skip for ci test
# class TestGLM41VServer(TestOpenAIVisionServer):
# @classmethod
@@ -379,5 +245,10 @@ class TestVILAServer(TestOpenAIVisionServer):
if __name__ == "__main__":
del TestOpenAIVisionServer
del (
TestOpenAIOmniServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
)
unittest.main()
import base64
import io
import json
import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
@@ -10,12 +8,7 @@ import requests
from PIL import Image
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase
# image
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
@@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(CustomTestCase):
class TestOpenAIOmniServerBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
cls.model = ""
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.process = None
cls.base_url += "/v1"
@classmethod
def tearDownClass(cls):
    if cls.process is not None:  # guard: the abstract base never launches a server
        kill_process_tree(cls.process.pid)
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_vision_request_kwargs(self):
return self.get_request_kwargs()
def get_request_kwargs(self):
return {}
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError("url must not be None")
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
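get_request_kwargs and its vision/audio variants form a small hook hierarchy: the base returns an empty dict, and model classes such as TestPhi4MMServer override them to inject per-modality extras like a lora_path in extra_body. get_or_download_file caches remote test assets under ~/.cache so repeated CI runs skip the download.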
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
self.assertIsNotNone(audio_response)
audio_response = audio_response.lower()
self.assertGreater(len(audio_response), 0)
return audio_response
def test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
def test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase):
return messages
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
# This test samples frames from the video as input, rather than passing the video directly
def test_video_images_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
@@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase):
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
def _test_video_chat_completion(self):
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def test_video_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
@@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase):
), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{"""
+ r""""color":"[\w]+","""
+ r""""number_of_cars":[\d]+"""
+ r"""\}"""
)
extra_kwargs = self.get_vision_request_kwargs()
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in the JSON format.",
},
],
},
],
temperature=0,
**extra_kwargs,
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["color"], str)
assert isinstance(js_obj["number_of_cars"], int)
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
audio_response = audio_response.lower()
self.assertIsNotNone(audio_response)
self.assertGreater(len(audio_response), 0)
return audio_response.lower()
def _test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
def _test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
def test_audio_chat_completion(self):
pass