"docs/vscode:/vscode.git/clone" did not exist on "bbcf2a8589f93acd401bd9e6367add6412eabc04"
Unverified Commit d17d19e5 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix mixed batch for multi modal models (#1702)

parent dd3809fa
......@@ -76,7 +76,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 5 --range-end 16
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......@@ -96,7 +96,7 @@ jobs:
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 16
python3 run_suite.py --suite minimal --range-begin 17
performance-test-1-gpu-part-1:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
......@@ -160,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
]
image_offsets = [
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
]
########## Encode Image ########
......@@ -358,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
prefix_len = prefix_lens_cpu[i]
# Multiple images
for j, image_offset in enumerate(image_offsets[i]):
for j, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset < prefix_len:
continue
......
"""
Usage:
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
"""
import base64
import io
import json
import os
import unittest
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
......@@ -288,6 +294,55 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert isinstance(js_obj["color"], str)
assert isinstance(js_obj["number_of_cars"], int)
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a very short sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": content},
],
temperature=0,
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment