Unverified Commit d09a51f1 authored by XinyuanTong's avatar XinyuanTong Committed by GitHub
Browse files

[feat&refactor] Enhance multimodal input support with refactor io_struct (#4938)


Signed-off-by: default avatarXinyuan Tong <justinning0323@outlook.com>
parent f8194b26
...@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union ...@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from PIL.Image import Image
# Fix a bug of Python threading # Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
...@@ -135,9 +136,19 @@ class Engine: ...@@ -135,9 +136,19 @@ class Engine:
sampling_params: Optional[Union[List[Dict], Dict]] = None, sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None, input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be an image instance, file name, URL, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # Can be formatted as:
image_data: Optional[Union[List[str], str]] = None, # - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
...@@ -190,9 +201,19 @@ class Engine: ...@@ -190,9 +201,19 @@ class Engine:
sampling_params: Optional[Union[List[Dict], Dict]] = None, sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None, input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be an image instance, file name, URL, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # Can be formatted as:
image_data: Optional[Union[List[str], str]] = None, # - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
...@@ -228,7 +249,13 @@ class Engine: ...@@ -228,7 +249,13 @@ class Engine:
def encode( def encode(
self, self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]], prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
image_data: Optional[Union[List[str], str]] = None, image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
) -> Dict: ) -> Dict:
""" """
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
......
...@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.model_executor.model_runner import LocalSerializedTensor
...@@ -56,9 +57,19 @@ class VerlEngine: ...@@ -56,9 +57,19 @@ class VerlEngine:
sampling_params: Optional[Union[List[Dict], Dict]] = None, sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None, input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be an image instance, file name, URL, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # Can be formatted as:
image_data: Optional[Union[List[str], str]] = None, # - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
......
...@@ -20,7 +20,13 @@ import copy ...@@ -20,7 +20,13 @@ import copy
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
# handle serialization of Image for pydantic
if TYPE_CHECKING:
from PIL.Image import Image
else:
Image = Any
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -42,10 +48,16 @@ class GenerateReqInput: ...@@ -42,10 +48,16 @@ class GenerateReqInput:
input_ids: Optional[Union[List[List[int]], List[int]]] = None input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds. # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be an image instance, file name, URL, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # Can be formatted as:
image_data: Optional[Union[List[str], str]] = None # - Single image for a single request
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string. # - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None audio_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below. # The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None sampling_params: Optional[Union[List[Dict], Dict]] = None
...@@ -84,6 +96,31 @@ class GenerateReqInput: ...@@ -84,6 +96,31 @@ class GenerateReqInput:
return_hidden_states: bool = False return_hidden_states: bool = False
def normalize_batch_and_arguments(self): def normalize_batch_and_arguments(self):
"""
Normalize the batch size and arguments for the request.
This method resolves various input formats and ensures all parameters
are properly formatted as either single values or batches depending on the input.
It also handles parallel sampling expansion and sets default values for
unspecified parameters.
Raises:
ValueError: If inputs are not properly specified (e.g., none or all of
text, input_ids, input_embeds are provided)
"""
self._validate_inputs()
self._determine_batch_size()
self._handle_parallel_sampling()
if self.is_single:
self._normalize_single_inputs()
else:
self._normalize_batch_inputs()
self._validate_session_params()
def _validate_inputs(self):
"""Validate that the input configuration is valid."""
if ( if (
self.text is None and self.input_ids is None and self.input_embeds is None self.text is None and self.input_ids is None and self.input_embeds is None
) or ( ) or (
...@@ -95,7 +132,8 @@ class GenerateReqInput: ...@@ -95,7 +132,8 @@ class GenerateReqInput:
"Either text, input_ids or input_embeds should be provided." "Either text, input_ids or input_embeds should be provided."
) )
# Derive the batch size def _determine_batch_size(self):
"""Determine if this is a single example or a batch and the batch size."""
if self.text is not None: if self.text is not None:
if isinstance(self.text, str): if isinstance(self.text, str):
self.is_single = True self.is_single = True
...@@ -119,21 +157,25 @@ class GenerateReqInput: ...@@ -119,21 +157,25 @@ class GenerateReqInput:
self.is_single = True self.is_single = True
self.batch_size = 1 self.batch_size = 1
else: else:
self.is_single = False
self.batch_size = len(self.input_embeds) self.batch_size = len(self.input_embeds)
# Handle parallel sampling def _handle_parallel_sampling(self):
# When parallel sampling is used, we always treat the input as a batch. """Handle parallel sampling parameters and adjust batch size if needed."""
# Determine parallel sample count
if self.sampling_params is None: if self.sampling_params is None:
self.parallel_sample_num = 1 self.parallel_sample_num = 1
elif isinstance(self.sampling_params, dict): elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1) self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list): else: # isinstance(self.sampling_params, list):
self.parallel_sample_num = self.sampling_params[0].get("n", 1) self.parallel_sample_num = self.sampling_params[0].get("n", 1)
assert all( for sampling_params in self.sampling_params:
self.parallel_sample_num == sampling_params.get("n", 1) if self.parallel_sample_num != sampling_params.get("n", 1):
for sampling_params in self.sampling_params raise ValueError(
), "The parallel_sample_num should be the same for all samples in sample params." "The parallel_sample_num should be the same for all samples in sample params."
)
# If using parallel sampling with a single example, convert to batch
if self.parallel_sample_num > 1 and self.is_single: if self.parallel_sample_num > 1 and self.is_single:
self.is_single = False self.is_single = False
if self.text is not None: if self.text is not None:
...@@ -141,8 +183,8 @@ class GenerateReqInput: ...@@ -141,8 +183,8 @@ class GenerateReqInput:
if self.input_ids is not None: if self.input_ids is not None:
self.input_ids = [self.input_ids] self.input_ids = [self.input_ids]
# Fill in default arguments def _normalize_single_inputs(self):
if self.is_single: """Normalize inputs for a single example."""
if self.sampling_params is None: if self.sampling_params is None:
self.sampling_params = {} self.sampling_params = {}
if self.rid is None: if self.rid is None:
...@@ -155,58 +197,142 @@ class GenerateReqInput: ...@@ -155,58 +197,142 @@ class GenerateReqInput:
self.top_logprobs_num = 0 self.top_logprobs_num = 0
if not self.token_ids_logprob: # covers both None and [] if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = None self.token_ids_logprob = None
else:
def _normalize_batch_inputs(self):
"""Normalize inputs for a batch of examples, including parallel sampling expansion."""
# Calculate expanded batch size
if self.parallel_sample_num == 1: if self.parallel_sample_num == 1:
num = self.batch_size num = self.batch_size
else: else:
# Expand parallel_sample_num # Expand parallel_sample_num
num = self.batch_size * self.parallel_sample_num num = self.batch_size * self.parallel_sample_num
if not self.image_data: # Expand input based on type
self._expand_inputs(num)
self._normalize_lora_paths(num)
self._normalize_image_data(num)
self._normalize_audio_data(num)
self._normalize_sampling_params(num)
self._normalize_rid(num)
self._normalize_logprob_params(num)
self._normalize_custom_logit_processor(num)
def _expand_inputs(self, num):
"""Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
if self.text is not None:
if not isinstance(self.text, list):
raise ValueError("Text should be a list for batch processing.")
self.text = self.text * self.parallel_sample_num
elif self.input_ids is not None:
if not isinstance(self.input_ids, list) or not isinstance(
self.input_ids[0], list
):
raise ValueError(
"input_ids should be a list of lists for batch processing."
)
self.input_ids = self.input_ids * self.parallel_sample_num
elif self.input_embeds is not None:
if not isinstance(self.input_embeds, list):
raise ValueError("input_embeds should be a list for batch processing.")
self.input_embeds = self.input_embeds * self.parallel_sample_num
def _normalize_lora_paths(self, num):
"""Normalize LoRA paths for batch processing."""
if self.lora_path is not None:
if isinstance(self.lora_path, str):
self.lora_path = [self.lora_path] * num
elif isinstance(self.lora_path, list):
self.lora_path = self.lora_path * self.parallel_sample_num
else:
raise ValueError("lora_path should be a list or a string.")
def _normalize_image_data(self, num):
"""Normalize image data for batch processing."""
if self.image_data is None:
self.image_data = [None] * num self.image_data = [None] * num
elif not isinstance(self.image_data, list): elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num # Single image, convert to list of single-image lists
self.image_data = [[self.image_data]] * num
self.modalities = ["image"] * num
elif isinstance(self.image_data, list): elif isinstance(self.image_data, list):
pass if len(self.image_data) != self.batch_size:
raise ValueError(
"The length of image_data should be equal to the batch size."
)
self.modalities = []
if len(self.image_data) > 0 and isinstance(self.image_data[0], list):
# Already a list of lists, keep as is
for i in range(len(self.image_data)):
if self.image_data[i] is None or self.image_data[i] == [None]:
self.modalities.append(None)
elif len(self.image_data[i]) == 1:
self.modalities.append("image")
elif len(self.image_data[i]) > 1:
self.modalities.append("multi-images")
# Expand parallel_sample_num
self.image_data = self.image_data * self.parallel_sample_num
self.modalities = self.modalities * self.parallel_sample_num
else:
# List of images for a batch, wrap each in a list
wrapped_images = [[img] for img in self.image_data]
# Expand for parallel sampling
self.image_data = wrapped_images * self.parallel_sample_num
self.modalities = ["image"] * num
def _normalize_audio_data(self, num):
"""Normalize audio data for batch processing."""
if self.audio_data is None: if self.audio_data is None:
self.audio_data = [None] * num self.audio_data = [None] * num
elif not isinstance(self.audio_data, list): elif not isinstance(self.audio_data, list):
self.audio_data = [self.audio_data] * num self.audio_data = [self.audio_data] * num
elif isinstance(self.audio_data, list): elif isinstance(self.audio_data, list):
pass self.audio_data = self.audio_data * self.parallel_sample_num
def _normalize_sampling_params(self, num):
"""Normalize sampling parameters for batch processing."""
if self.sampling_params is None: if self.sampling_params is None:
self.sampling_params = [{}] * num self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list): elif isinstance(self.sampling_params, dict):
self.sampling_params = [self.sampling_params] * num self.sampling_params = [self.sampling_params] * num
else: # Already a list
self.sampling_params = self.sampling_params * self.parallel_sample_num
def _normalize_rid(self, num):
"""Normalize request IDs for batch processing."""
if self.rid is None: if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)] self.rid = [uuid.uuid4().hex for _ in range(num)]
elif not isinstance(self.rid, list):
raise ValueError("The rid should be a list for batch processing.")
def _normalize_logprob_params(self, num):
"""Normalize logprob-related parameters for batch processing."""
# Helper function to normalize a parameter
def normalize_param(param, default_value, param_name):
if param is None:
return [default_value] * num
elif not isinstance(param, list):
return [param] * num
else: else:
assert isinstance(self.rid, list), "The rid should be a list." if self.parallel_sample_num > 1:
raise ValueError(
if self.return_logprob is None: f"Cannot use list {param_name} with parallel_sample_num > 1"
self.return_logprob = [False] * num )
elif not isinstance(self.return_logprob, list): return param
self.return_logprob = [self.return_logprob] * num
else:
assert self.parallel_sample_num == 1
if self.logprob_start_len is None:
self.logprob_start_len = [-1] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
else:
assert self.parallel_sample_num == 1
if self.top_logprobs_num is None: # Normalize each logprob parameter
self.top_logprobs_num = [0] * num self.return_logprob = normalize_param(
elif not isinstance(self.top_logprobs_num, list): self.return_logprob, False, "return_logprob"
self.top_logprobs_num = [self.top_logprobs_num] * num )
else: self.logprob_start_len = normalize_param(
assert self.parallel_sample_num == 1 self.logprob_start_len, -1, "logprob_start_len"
)
self.top_logprobs_num = normalize_param(
self.top_logprobs_num, 0, "top_logprobs_num"
)
# Handle token_ids_logprob specially due to its nested structure
if not self.token_ids_logprob: # covers both None and [] if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = [None] * num self.token_ids_logprob = [None] * num
elif not isinstance(self.token_ids_logprob, list): elif not isinstance(self.token_ids_logprob, list):
...@@ -215,23 +341,32 @@ class GenerateReqInput: ...@@ -215,23 +341,32 @@ class GenerateReqInput:
self.token_ids_logprob = [ self.token_ids_logprob = [
copy.deepcopy(self.token_ids_logprob) for _ in range(num) copy.deepcopy(self.token_ids_logprob) for _ in range(num)
] ]
else: elif self.parallel_sample_num > 1:
assert self.parallel_sample_num == 1 raise ValueError(
"Cannot use list token_ids_logprob with parallel_sample_num > 1"
)
def _normalize_custom_logit_processor(self, num):
"""Normalize custom logit processor for batch processing."""
if self.custom_logit_processor is None: if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list): elif not isinstance(self.custom_logit_processor, list):
self.custom_logit_processor = [self.custom_logit_processor] * num self.custom_logit_processor = [self.custom_logit_processor] * num
else: elif self.parallel_sample_num > 1:
assert self.parallel_sample_num == 1 raise ValueError(
"Cannot use list custom_logit_processor with parallel_sample_num > 1"
)
# Other checks def _validate_session_params(self):
"""Validate that session parameters are properly formatted."""
if self.session_params is not None: if self.session_params is not None:
assert isinstance(self.session_params, dict) or isinstance( if not isinstance(self.session_params, dict) and not isinstance(
self.session_params[0], dict self.session_params[0], dict
) ):
raise ValueError("Session params must be a dict or a list of dicts.")
def regenerate_rid(self): def regenerate_rid(self):
"""Generate a new request ID and return it."""
self.rid = uuid.uuid4().hex self.rid = uuid.uuid4().hex
return self.rid return self.rid
...@@ -305,8 +440,15 @@ class TokenizedGenerateReqInput: ...@@ -305,8 +440,15 @@ class TokenizedGenerateReqInput:
class EmbeddingReqInput: class EmbeddingReqInput:
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None text: Optional[Union[List[str], str]] = None
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be an image instance, file name, URL, or base64 encoded string.
image_data: Optional[Union[List[str], str]] = None # Can be formatted as:
# - Single image for a single request
# - List of images (one per request in a batch)
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The request id. # The request id.
......
import copy
import unittest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
)
class TestGenerateReqInputNormalization(CustomTestCase):
"""Test the normalization of GenerateReqInput for batch processing and different input formats."""
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
def setUp(self):
# Common setup for all tests
self.base_req = GenerateReqInput(
text=["Hello", "World"],
sampling_params=[{}, {}],
rid=["id1", "id2"],
)
def test_single_image_to_list_of_lists(self):
"""Test that a single image is converted to a list of single-image lists."""
req = copy.deepcopy(self.base_req)
req.image_data = "single_image.jpg" # A single image (non-list)
req.normalize_batch_and_arguments()
# Should be converted to [[image], [image]]
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
self.assertEqual(req.image_data[0][0], "single_image.jpg")
self.assertEqual(req.image_data[1][0], "single_image.jpg")
# Check modalities
self.assertEqual(req.modalities, ["image", "image"])
def test_list_of_images_to_list_of_lists(self):
"""Test that a list of images is converted to a list of single-image lists."""
req = copy.deepcopy(self.base_req)
req.image_data = ["image1.jpg", "image2.jpg"] # List of images
req.normalize_batch_and_arguments()
# Should be converted to [[image1], [image2]]
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
self.assertEqual(req.image_data[0][0], "image1.jpg")
self.assertEqual(req.image_data[1][0], "image2.jpg")
# Check modalities
self.assertEqual(req.modalities, ["image", "image"])
def test_list_of_lists_with_different_modalities(self):
"""Test handling of list of lists of images with different modalities."""
req = copy.deepcopy(self.base_req)
req.image_data = [
["image1.jpg"], # Single image (image modality)
["image2.jpg", "image3.jpg"], # Multiple images (multi-images modality)
]
req.normalize_batch_and_arguments()
# Structure should remain the same
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 2)
# Check modalities
self.assertEqual(req.modalities, ["image", "multi-images"])
def test_list_of_lists_with_none_values(self):
"""Test handling of list of lists with None values."""
req = copy.deepcopy(self.base_req)
req.image_data = [
[None], # None value
["image.jpg"], # Single image
]
req.normalize_batch_and_arguments()
# Structure should remain the same
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
# Check modalities
self.assertEqual(req.modalities, [None, "image"])
def test_expanding_parallel_sample_correlation(self):
"""Test that when expanding with parallel samples, prompts, images and modalities are properly correlated."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg", "image3.jpg"],
]
req.sampling_params = {"n": 3} # All prompts get 3 samples
# Define expected values before normalization
expected_text = req.text * 3
expected_images = req.image_data * 3
expected_modalities = ["image", "multi-images"] * 3
req.normalize_batch_and_arguments()
# Should be expanded to 6 items (2 original * 3 parallel)
self.assertEqual(len(req.image_data), 6)
# Check that images are properly expanded
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Ensure that text items are properly duplicated too
self.assertEqual(req.text, expected_text)
def test_list_of_lists_with_none_values(self):
"""Test handling of list of lists with None values."""
req = copy.deepcopy(self.base_req)
req.image_data = [
[None], # None value
["image.jpg"], # Single image
]
req.normalize_batch_and_arguments()
# Structure should remain the same
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
# Check modalities
self.assertEqual(req.modalities, [None, "image"])
def test_specific_parallel_n_per_sample(self):
"""Test parallel expansion when different samples have different n values."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg", "image3.jpg"],
]
req.sampling_params = [
{"n": 2},
{"n": 2},
] # First prompt gets 2 samples, second prompt gets 2 samples
expected_images = req.image_data * 2
expected_modalities = ["image", "multi-images"] * 2
expected_text = req.text * 2
req.normalize_batch_and_arguments()
# Should be expanded to 4 items (2 original * 2 parallel)
self.assertEqual(len(req.image_data), 4)
# Check that the first 2 are copies for the first prompt
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Check text expansion
self.assertEqual(req.text, expected_text)
def test_mixed_none_and_images_with_parallel_samples(self):
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
req.image_data = [
["image1.jpg"],
None,
["image3_1.jpg", "image3_2.jpg"],
]
req.sampling_params = {"n": 2} # All prompts get 2 samples
expected_images = req.image_data * 2
expected_modalities = ["image", None, "multi-images"] * 2
expected_text = req.text * 2
req.normalize_batch_and_arguments()
# Should be expanded to 6 items (3 original * 2 parallel)
self.assertEqual(len(req.image_data), 6)
# Check image data
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Check text expansion
self.assertEqual(req.text, expected_text)
def test_correlation_with_sampling_params(self):
"""Test that sampling parameters are correctly correlated with prompts during expansion."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg"],
]
req.sampling_params = [
{"temperature": 0.7, "n": 2},
{"temperature": 0.9, "n": 2},
]
req.normalize_batch_and_arguments()
# Check sampling params expansion
self.assertEqual(len(req.sampling_params), 4)
self.assertEqual(req.sampling_params[0]["temperature"], 0.7)
self.assertEqual(req.sampling_params[1]["temperature"], 0.9)
self.assertEqual(req.sampling_params[2]["temperature"], 0.7)
self.assertEqual(req.sampling_params[3]["temperature"], 0.9)
# Should be expanded to 4 items (2 original * 2 parallel)
self.assertEqual(len(req.image_data), 4)
# Check correlation with images
self.assertEqual(req.image_data[0], ["image1.jpg"])
self.assertEqual(req.image_data[1], ["image2.jpg"])
self.assertEqual(req.image_data[2], ["image1.jpg"])
self.assertEqual(req.image_data[3], ["image2.jpg"])
def test_single_example_with_image(self):
"""Test handling of single example with image."""
req = GenerateReqInput(
text="Hello",
image_data="single_image.jpg",
)
req.normalize_batch_and_arguments()
# For single examples, image_data doesn't get processed into lists
self.assertEqual(req.image_data, "single_image.jpg")
self.assertIsNone(req.modalities) # Modalities isn't set for single examples
def test_single_to_batch_with_parallel_sampling(self):
"""Test single example converted to batch with parallel sampling."""
req = GenerateReqInput(
text="Hello",
image_data="single_image.jpg",
sampling_params={"n": 3}, # parallel_sample_num = 3
)
# Define expected values before normalization
expected_text = ["Hello"] * 3
req.normalize_batch_and_arguments()
# Should be converted to batch with text=["Hello"]
self.assertEqual(req.text, expected_text)
# Image should be automatically wrapped to list of lists with length 1*3=3
self.assertEqual(len(req.image_data), 3)
self.assertEqual(req.image_data[0][0], "single_image.jpg")
self.assertEqual(req.image_data[1][0], "single_image.jpg")
self.assertEqual(req.image_data[2][0], "single_image.jpg")
# Modalities should be set for all 3 examples
self.assertEqual(req.modalities, ["image", "image", "image"])
def test_audio_data_handling(self):
"""Test handling of audio_data."""
req = copy.deepcopy(self.base_req)
req.audio_data = "audio.mp3" # Single audio
req.normalize_batch_and_arguments()
# Should be converted to ["audio.mp3", "audio.mp3"]
self.assertEqual(len(req.audio_data), 2)
self.assertEqual(req.audio_data[0], "audio.mp3")
self.assertEqual(req.audio_data[1], "audio.mp3")
# Test with list
req = copy.deepcopy(self.base_req)
req.audio_data = ["audio1.mp3", "audio2.mp3"]
req.normalize_batch_and_arguments()
# Should remain the same
self.assertEqual(len(req.audio_data), 2)
self.assertEqual(req.audio_data[0], "audio1.mp3")
self.assertEqual(req.audio_data[1], "audio2.mp3")
def test_input_ids_normalization(self):
"""Test normalization of input_ids instead of text."""
# Test single input_ids
req = GenerateReqInput(input_ids=[1, 2, 3])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
self.assertEqual(req.batch_size, 1)
# Test batch input_ids
req = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
req.normalize_batch_and_arguments()
self.assertFalse(req.is_single)
self.assertEqual(req.batch_size, 2)
# Test with parallel sampling
req = GenerateReqInput(
input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2}
)
req.normalize_batch_and_arguments()
self.assertEqual(len(req.input_ids), 4) # 2 original * 2 parallel
def test_input_embeds_normalization(self):
"""Test normalization of input_embeds."""
# Test single input_embeds
req = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
self.assertEqual(req.batch_size, 1)
# Test batch input_embeds
req = GenerateReqInput(input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]])
req.normalize_batch_and_arguments()
self.assertFalse(req.is_single)
self.assertEqual(req.batch_size, 2)
def test_lora_path_normalization(self):
"""Test normalization of lora_path."""
# Test single lora_path with batch input
req = GenerateReqInput(text=["Hello", "World"], lora_path="path/to/lora")
# Define expected lora_paths before normalization
expected_lora_paths = ["path/to/lora", "path/to/lora"]
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
# Test list of lora_paths
req = GenerateReqInput(text=["Hello", "World"], lora_path=["path1", "path2"])
# Define expected lora_paths before normalization
expected_lora_paths = ["path1", "path2"]
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
# Test with parallel sampling
req = GenerateReqInput(
text=["Hello", "World"],
lora_path=["path1", "path2"],
sampling_params={"n": 2},
)
# Define expected lora_paths before normalization
expected_lora_paths = ["path1", "path2"] * 2
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
def test_logprob_parameters_normalization(self):
"""Test normalization of logprob-related parameters."""
# Test single example
req = GenerateReqInput(
text="Hello",
return_logprob=True,
logprob_start_len=10,
top_logprobs_num=5,
token_ids_logprob=[7, 8, 9],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, True)
self.assertEqual(req.logprob_start_len, 10)
self.assertEqual(req.top_logprobs_num, 5)
self.assertEqual(req.token_ids_logprob, [7, 8, 9])
# Test batch with scalar values
req = GenerateReqInput(
text=["Hello", "World"],
return_logprob=True,
logprob_start_len=10,
top_logprobs_num=5,
token_ids_logprob=[7, 8, 9],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, True])
self.assertEqual(req.logprob_start_len, [10, 10])
self.assertEqual(req.top_logprobs_num, [5, 5])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [7, 8, 9]])
# Test batch with list values
req = GenerateReqInput(
text=["Hello", "World"],
return_logprob=[True, False],
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, False])
self.assertEqual(req.logprob_start_len, [10, 5])
self.assertEqual(req.top_logprobs_num, [5, 3])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
def test_custom_logit_processor_normalization(self):
"""Test normalization of custom_logit_processor."""
# Test single processor
req = GenerateReqInput(
text=["Hello", "World"], custom_logit_processor="serialized_processor"
)
req.normalize_batch_and_arguments()
self.assertEqual(
req.custom_logit_processor, ["serialized_processor", "serialized_processor"]
)
# Test list of processors
req = GenerateReqInput(
text=["Hello", "World"], custom_logit_processor=["processor1", "processor2"]
)
req.normalize_batch_and_arguments()
self.assertEqual(req.custom_logit_processor, ["processor1", "processor2"])
def test_session_params_handling(self):
"""Test handling of session_params."""
# Test with dict
req = GenerateReqInput(
text=["Hello", "World"], session_params={"id": "session1", "offset": 10}
)
req.normalize_batch_and_arguments()
self.assertEqual(req.session_params, {"id": "session1", "offset": 10})
# Test with list of dicts
req = GenerateReqInput(
text=["Hello", "World"],
session_params=[{"id": "session1"}, {"id": "session2"}],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.session_params, [{"id": "session1"}, {"id": "session2"}])
def test_getitem_method(self):
"""Test the __getitem__ method."""
req = GenerateReqInput(
text=["Hello", "World"],
image_data=[["img1.jpg"], ["img2.jpg"]],
audio_data=["audio1.mp3", "audio2.mp3"],
sampling_params=[{"temp": 0.7}, {"temp": 0.8}],
rid=["id1", "id2"],
return_logprob=[True, False],
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
stream=True,
log_metrics=True,
modalities=["image", "image"],
lora_path=["path1", "path2"],
custom_logit_processor=["processor1", "processor2"],
return_hidden_states=True,
)
req.normalize_batch_and_arguments()
# Get the first item
item0 = req[0]
self.assertEqual(item0.text, "Hello")
self.assertEqual(item0.image_data, ["img1.jpg"])
self.assertEqual(item0.audio_data, "audio1.mp3")
self.assertEqual(item0.sampling_params, {"temp": 0.7})
self.assertEqual(item0.rid, "id1")
self.assertEqual(item0.return_logprob, True)
self.assertEqual(item0.logprob_start_len, 10)
self.assertEqual(item0.top_logprobs_num, 5)
self.assertEqual(item0.token_ids_logprob, [7, 8, 9])
self.assertEqual(item0.stream, True)
self.assertEqual(item0.log_metrics, True)
self.assertEqual(item0.modalities, "image")
self.assertEqual(item0.lora_path, "path1")
self.assertEqual(item0.custom_logit_processor, "processor1")
self.assertEqual(item0.return_hidden_states, True)
def test_regenerate_rid(self):
"""Test the regenerate_rid method."""
req = GenerateReqInput(text="Hello")
req.normalize_batch_and_arguments()
original_rid = req.rid
new_rid = req.regenerate_rid()
self.assertNotEqual(original_rid, new_rid)
self.assertEqual(req.rid, new_rid)
def test_error_cases(self):
"""Test various error cases."""
# Test when neither text, input_ids, nor input_embeds is provided
with self.assertRaises(ValueError):
req = GenerateReqInput()
req.normalize_batch_and_arguments()
# Test when all of text, input_ids, and input_embeds are provided
with self.assertRaises(ValueError):
req = GenerateReqInput(
text="Hello", input_ids=[1, 2, 3], input_embeds=[[0.1, 0.2]]
)
req.normalize_batch_and_arguments()
def test_multiple_input_formats(self):
"""Test different combinations of input formats."""
# Test with text only
req = GenerateReqInput(text="Hello")
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
# Test with input_ids only
req = GenerateReqInput(input_ids=[1, 2, 3])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
# Test with input_embeds only
req = GenerateReqInput(input_embeds=[[0.1, 0.2]])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment