Unverified Commit 1df84ff4 authored by Mick's avatar Mick Committed by GitHub
Browse files

ci: simplify multi-modality tests by using mixins (#9006)

parent 66d6be08
...@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC): ...@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
if videos: if videos:
kwargs["videos"] = videos kwargs["videos"] = videos
if audios: if audios:
if self.arch in { if self._processor.__class__.__name__ in {
"Gemma3nForConditionalGeneration", "Gemma3nProcessor",
"Qwen2AudioForConditionalGeneration", "Qwen2AudioProcessor",
}: }:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios kwargs["audio"] = audios
......
...@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM ...@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image, logger from sglang.srt.utils import ImageData, load_image, logger
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
@staticmethod @staticmethod
def _process_single_image_task( def _process_single_image_task(
image_data: Union[str, bytes], image_data: Union[str, bytes, ImageData],
image_aspect_ratio: Optional[str] = None, image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None, image_grid_pinpoints: Optional[str] = None,
processor=None, processor=None,
...@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
image_processor = processor.image_processor image_processor = processor.image_processor
try: try:
image, image_size = load_image(image_data) url = image_data.url if isinstance(image_data, ImageData) else image_data
image, image_size = load_image(url)
if image_size is not None: if image_size is not None:
# It is a video with multiple images # It is a video with multiple images
image_hash = hash(image_data) image_hash = hash(url)
pixel_values = image_processor(image)["pixel_values"] pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)): for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16) pixel_values[_] = pixel_values[_].astype(np.float16)
...@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
return pixel_values, image_hash, image_size return pixel_values, image_hash, image_size
else: else:
# It is an image # It is an image
image_hash = hash(image_data) image_hash = hash(url)
if image_aspect_ratio == "pad": if image_aspect_ratio == "pad":
image = expand2square( image = expand2square(
image, image,
...@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image( async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str self,
image_data: Union[bytes, str, ImageData],
aspect_ratio: str,
grid_pinpoints: str,
): ):
if self.cpu_executor is not None: if self.cpu_executor is not None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
...@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes, ImageData]],
input_text, input_text,
request_obj, request_obj,
*args, *args,
......
...@@ -110,8 +110,8 @@ suites = { ...@@ -110,8 +110,8 @@ suites = {
TestFile("test_utils_update_weights.py", 48), TestFile("test_utils_update_weights.py", 48),
TestFile("test_vision_chunked_prefill.py", 175), TestFile("test_vision_chunked_prefill.py", 175),
TestFile("test_vlm_input_format.py", 300), TestFile("test_vlm_input_format.py", 300),
TestFile("test_vision_openai_server_a.py", 989), TestFile("test_vision_openai_server_a.py", 403),
TestFile("test_vision_openai_server_b.py", 620), TestFile("test_vision_openai_server_b.py", 446),
], ],
"per-commit-2-gpu": [ "per-commit-2-gpu": [
TestFile("lora/test_lora_tp.py", 116), TestFile("lora/test_lora_tp.py", 116),
......
...@@ -8,16 +8,28 @@ import unittest ...@@ -8,16 +8,28 @@ import unittest
from test_vision_openai_server_common import * from test_vision_openai_server_common import *
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
class TestQwen2VLServer(TestOpenAIVisionServer): class TestLlava(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.base_url += "/v1"
class TestQwen2VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct" cls.model = "Qwen/Qwen2-VL-7B-Instruct"
...@@ -37,11 +49,8 @@ class TestQwen2VLServer(TestOpenAIVisionServer): ...@@ -37,11 +49,8 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
self._test_video_chat_completion()
class TestQwen2_5_VLServer(TestOpenAIVisionServer): class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "Qwen/Qwen2.5-VL-7B-Instruct" cls.model = "Qwen/Qwen2.5-VL-7B-Instruct"
...@@ -61,9 +70,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer): ...@@ -61,9 +70,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
self._test_video_chat_completion()
class TestVLMContextLengthIssue(CustomTestCase): class TestVLMContextLengthIssue(CustomTestCase):
@classmethod @classmethod
...@@ -137,11 +143,8 @@ class TestVLMContextLengthIssue(CustomTestCase): ...@@ -137,11 +143,8 @@ class TestVLMContextLengthIssue(CustomTestCase):
# ) # )
# cls.base_url += "/v1" # cls.base_url += "/v1"
# def test_video_chat_completion(self):
# pass
class TestMinicpmvServer(ImageOpenAITestMixin):
class TestMinicpmvServer(TestOpenAIVisionServer):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "openbmb/MiniCPM-V-2_6" cls.model = "openbmb/MiniCPM-V-2_6"
...@@ -162,7 +165,7 @@ class TestMinicpmvServer(TestOpenAIVisionServer): ...@@ -162,7 +165,7 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestInternVL2_5Server(TestOpenAIVisionServer): class TestInternVL2_5Server(ImageOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "OpenGVLab/InternVL2_5-2B" cls.model = "OpenGVLab/InternVL2_5-2B"
...@@ -181,7 +184,7 @@ class TestInternVL2_5Server(TestOpenAIVisionServer): ...@@ -181,7 +184,7 @@ class TestInternVL2_5Server(TestOpenAIVisionServer):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestMinicpmoServer(TestOpenAIVisionServer): class TestMinicpmoServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "openbmb/MiniCPM-o-2_6" cls.model = "openbmb/MiniCPM-o-2_6"
...@@ -201,12 +204,8 @@ class TestMinicpmoServer(TestOpenAIVisionServer): ...@@ -201,12 +204,8 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
self._test_audio_ambient_completion()
class TestMimoVLServer(TestOpenAIVisionServer): class TestMimoVLServer(ImageOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "XiaomiMiMo/MiMo-VL-7B-RL" cls.model = "XiaomiMiMo/MiMo-VL-7B-RL"
...@@ -228,6 +227,95 @@ class TestMimoVLServer(TestOpenAIVisionServer): ...@@ -228,6 +227,95 @@ class TestMimoVLServer(TestOpenAIVisionServer):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestVILAServer(ImageOpenAITestMixin):
@classmethod
def setUpClass(cls):
cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--trust-remote-code",
"--context-length=65536",
f"--revision={cls.revision}",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod
def setUpClass(cls):
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
from huggingface_hub import constants, snapshot_download
snapshot_download(
"microsoft/Phi-4-multimodal-instruct",
allow_patterns=["**/adapter_config.json"],
)
cls.model = "microsoft/Phi-4-multimodal-instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.70",
"--disable-radix-cache",
"--max-loras-per-batch",
"2",
"--revision",
revision,
"--lora-paths",
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
def get_vision_request_kwargs(self):
return {
"extra_body": {
"lora_path": "vision",
"top_k": 1,
"top_p": 1.0,
}
}
def get_audio_request_kwargs(self):
return {
"extra_body": {
"lora_path": "speech",
"top_k": 1,
"top_p": 1.0,
}
}
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
def test_audio_ambient_completion(self):
pass
if __name__ == "__main__": if __name__ == "__main__":
del TestOpenAIVisionServer del (
TestOpenAIOmniServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
)
unittest.main() unittest.main()
...@@ -4,12 +4,11 @@ from test_vision_openai_server_common import * ...@@ -4,12 +4,11 @@ from test_vision_openai_server_common import *
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server, popen_launch_server,
) )
class TestPixtralServer(TestOpenAIVisionServer): class TestPixtralServer(ImageOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "mistral-community/pixtral-12b" cls.model = "mistral-community/pixtral-12b"
...@@ -29,11 +28,8 @@ class TestPixtralServer(TestOpenAIVisionServer): ...@@ -29,11 +28,8 @@ class TestPixtralServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestMistral3_1Server(TestOpenAIVisionServer): class TestMistral3_1Server(ImageOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "unsloth/Mistral-Small-3.1-24B-Instruct-2503" cls.model = "unsloth/Mistral-Small-3.1-24B-Instruct-2503"
...@@ -53,11 +49,8 @@ class TestMistral3_1Server(TestOpenAIVisionServer): ...@@ -53,11 +49,8 @@ class TestMistral3_1Server(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestDeepseekVL2Server(ImageOpenAITestMixin):
class TestDeepseekVL2Server(TestOpenAIVisionServer):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "deepseek-ai/deepseek-vl2-small" cls.model = "deepseek-ai/deepseek-vl2-small"
...@@ -77,11 +70,8 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer): ...@@ -77,11 +70,8 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestJanusProServer(ImageOpenAITestMixin):
class TestJanusProServer(TestOpenAIVisionServer):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "deepseek-ai/Janus-Pro-7B" cls.model = "deepseek-ai/Janus-Pro-7B"
...@@ -104,10 +94,6 @@ class TestJanusProServer(TestOpenAIVisionServer): ...@@ -104,10 +94,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
def test_video_images_chat_completion(self): def test_video_images_chat_completion(self):
pass pass
def test_single_image_chat_completion(self):
# Skip this test because it is flaky
pass
## Skip for ci test ## Skip for ci test
# class TestLlama4Server(TestOpenAIVisionServer): # class TestLlama4Server(TestOpenAIVisionServer):
...@@ -135,11 +121,8 @@ class TestJanusProServer(TestOpenAIVisionServer): ...@@ -135,11 +121,8 @@ class TestJanusProServer(TestOpenAIVisionServer):
# ) # )
# cls.base_url += "/v1" # cls.base_url += "/v1"
# def test_video_chat_completion(self):
# pass
class TestGemma3itServer(TestOpenAIVisionServer): class TestGemma3itServer(ImageOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "google/gemma-3-4b-it" cls.model = "google/gemma-3-4b-it"
...@@ -160,11 +143,8 @@ class TestGemma3itServer(TestOpenAIVisionServer): ...@@ -160,11 +143,8 @@ class TestGemma3itServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
class TestGemma3nServer(TestOpenAIVisionServer): class TestGemma3nServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "google/gemma-3n-E4B-it" cls.model = "google/gemma-3n-E4B-it"
...@@ -184,16 +164,15 @@ class TestGemma3nServer(TestOpenAIVisionServer): ...@@ -184,16 +164,15 @@ class TestGemma3nServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_audio_chat_completion(self): # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
self._test_audio_speech_completion() def test_audio_ambient_completion(self):
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM pass
# self._test_audio_ambient_completion()
def _test_mixed_image_audio_chat_completion(self): def _test_mixed_image_audio_chat_completion(self):
self._test_mixed_image_audio_chat_completion() self._test_mixed_image_audio_chat_completion()
class TestQwen2AudioServer(TestOpenAIVisionServer): class TestQwen2AudioServer(AudioOpenAITestMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "Qwen/Qwen2-Audio-7B-Instruct" cls.model = "Qwen/Qwen2-Audio-7B-Instruct"
...@@ -211,36 +190,8 @@ class TestQwen2AudioServer(TestOpenAIVisionServer): ...@@ -211,36 +190,8 @@ class TestQwen2AudioServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
self._test_audio_ambient_completion()
# Qwen2Audio does not support image class TestKimiVLServer(ImageOpenAITestMixin):
def test_single_image_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_multi_turn_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_multi_images_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_video_images_chat_completion(self):
pass
# Qwen2Audio does not support image
def test_regex(self):
pass
# Qwen2Audio does not support image
def test_mixed_batch(self):
pass
class TestKimiVLServer(TestOpenAIVisionServer):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "moonshotai/Kimi-VL-A3B-Instruct" cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
...@@ -266,91 +217,6 @@ class TestKimiVLServer(TestOpenAIVisionServer): ...@@ -266,91 +217,6 @@ class TestKimiVLServer(TestOpenAIVisionServer):
pass pass
class TestPhi4MMServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
from huggingface_hub import constants, snapshot_download
snapshot_download(
"microsoft/Phi-4-multimodal-instruct",
allow_patterns=["**/adapter_config.json"],
)
cls.model = "microsoft/Phi-4-multimodal-instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--mem-fraction-static",
"0.70",
"--disable-radix-cache",
"--max-loras-per-batch",
"2",
"--revision",
revision,
"--lora-paths",
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
def get_vision_request_kwargs(self):
return {
"extra_body": {
"lora_path": "vision",
"top_k": 1,
"top_p": 1.0,
}
}
def get_audio_request_kwargs(self):
return {
"extra_body": {
"lora_path": "speech",
"top_k": 1,
"top_p": 1.0,
}
}
def test_audio_chat_completion(self):
self._test_audio_speech_completion()
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
# self._test_audio_ambient_completion()
class TestVILAServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--trust-remote-code",
"--context-length=65536",
f"--revision={cls.revision}",
"--cuda-graph-max-bs",
"4",
],
)
cls.base_url += "/v1"
# Skip for ci test # Skip for ci test
# class TestGLM41VServer(TestOpenAIVisionServer): # class TestGLM41VServer(TestOpenAIVisionServer):
# @classmethod # @classmethod
...@@ -379,5 +245,10 @@ class TestVILAServer(TestOpenAIVisionServer): ...@@ -379,5 +245,10 @@ class TestVILAServer(TestOpenAIVisionServer):
if __name__ == "__main__": if __name__ == "__main__":
del TestOpenAIVisionServer del (
TestOpenAIOmniServerBase,
ImageOpenAITestMixin,
VideoOpenAITestMixin,
AudioOpenAITestMixin,
)
unittest.main() unittest.main()
import base64 import base64
import io import io
import json
import os import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
import openai import openai
...@@ -10,12 +8,7 @@ import requests ...@@ -10,12 +8,7 @@ import requests
from PIL import Image from PIL import Image
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
# image # image
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png" IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
...@@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test ...@@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(CustomTestCase): class TestOpenAIOmniServerBase(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" cls.model = ""
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456" cls.api_key = "sk-123456"
cls.process = popen_launch_server( cls.process = None
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.base_url += "/v1" cls.base_url += "/v1"
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_vision_request_kwargs(self): def get_vision_request_kwargs(self):
return self.get_request_kwargs() return self.get_request_kwargs()
def get_request_kwargs(self): def get_request_kwargs(self):
return {} return {}
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
audio_response = audio_response.lower()
self.assertIsNotNone(audio_response)
self.assertGreater(len(audio_response), 0)
return audio_response.lower()
def test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
def test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
def test_single_image_chat_completion(self): def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
...@@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase):
return messages return messages
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
# this test samples frames of video as input, but not video directly
def test_video_images_chat_completion(self): def test_video_images_chat_completion(self):
url = VIDEO_JOBS_URL url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url) file_path = self.get_or_download_file(url)
...@@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase):
self.assertIsNotNone(video_response) self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
def _test_video_chat_completion(self):
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def test_video_chat_completion(self):
url = VIDEO_JOBS_URL url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url) file_path = self.get_or_download_file(url)
...@@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase):
), f"video_response: {video_response}, should contain 'black' or 'dark'" ), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response) self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0) self.assertGreater(len(video_response), 0)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{"""
+ r""""color":"[\w]+","""
+ r""""number_of_cars":[\d]+"""
+ r"""\}"""
)
extra_kwargs = self.get_vision_request_kwargs()
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in the JSON format.",
},
],
},
],
temperature=0,
**extra_kwargs,
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["color"], str)
assert isinstance(js_obj["number_of_cars"], int)
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
audio_response = audio_response.lower()
self.assertIsNotNone(audio_response)
self.assertGreater(len(audio_response), 0)
return audio_response.lower()
def _test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: |{audio_response}| should contain |{check_word}|"
def _test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
def test_audio_chat_completion(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment