Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d17d19e5
Unverified
Commit
d17d19e5
authored
Oct 17, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 17, 2024
Browse files
Fix mixed batch for multi modal models (#1702)
parent
dd3809fa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
6 deletions
+58
-6
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+2
-2
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+1
-4
test/srt/test_vision_openai_server.py
test/srt/test_vision_openai_server.py
+55
-0
No files found.
.github/workflows/pr-test.yml
View file @
d17d19e5
...
@@ -76,7 +76,7 @@ jobs:
...
@@ -76,7 +76,7 @@ jobs:
timeout-minutes
:
20
timeout-minutes
:
20
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 5 --range-end 1
6
python3 run_suite.py --suite minimal --range-begin 5 --range-end 1
7
unit-test-backend-part-3
:
unit-test-backend-part-3
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -96,7 +96,7 @@ jobs:
...
@@ -96,7 +96,7 @@ jobs:
timeout-minutes
:
20
timeout-minutes
:
20
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 1
6
python3 run_suite.py --suite minimal --range-begin 1
7
performance-test-1-gpu-part-1
:
performance-test-1-gpu-part-1
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
python/sglang/srt/models/llava.py
View file @
d17d19e5
...
@@ -160,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
...
@@ -160,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes
=
[
image_sizes
=
[
image_inputs
[
i
].
image_sizes
for
i
in
range
(
bs
)
if
need_vision
[
i
]
image_inputs
[
i
].
image_sizes
for
i
in
range
(
bs
)
if
need_vision
[
i
]
]
]
image_offsets
=
[
image_inputs
[
i
].
image_offsets
for
i
in
range
(
bs
)
if
need_vision
[
i
]
]
########## Encode Image ########
########## Encode Image ########
...
@@ -358,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
...
@@ -358,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
prefix_len
=
prefix_lens_cpu
[
i
]
prefix_len
=
prefix_lens_cpu
[
i
]
# Multiple images
# Multiple images
for
j
,
image_offset
in
enumerate
(
image_offsets
[
i
]
):
for
j
,
image_offset
in
enumerate
(
image_
inputs
[
i
].
image_
offsets
):
if
image_offset
<
prefix_len
:
if
image_offset
<
prefix_len
:
continue
continue
...
...
test/srt/test_vision_openai_server.py
View file @
d17d19e5
"""
Usage:
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
"""
import
base64
import
base64
import
io
import
io
import
json
import
json
import
os
import
os
import
unittest
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
import
numpy
as
np
import
numpy
as
np
import
openai
import
openai
...
@@ -288,6 +294,55 @@ class TestOpenAIVisionServer(unittest.TestCase):
...
@@ -288,6 +294,55 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert
isinstance
(
js_obj
[
"color"
],
str
)
assert
isinstance
(
js_obj
[
"color"
],
str
)
assert
isinstance
(
js_obj
[
"number_of_cars"
],
int
)
assert
isinstance
(
js_obj
[
"number_of_cars"
],
int
)
def
run_decode_with_image
(
self
,
image_id
):
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
content
=
[]
if
image_id
==
0
:
content
.
append
(
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}
)
elif
image_id
==
1
:
content
.
append
(
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
}
)
else
:
pass
content
.
append
(
{
"type"
:
"text"
,
"text"
:
"Describe this image in a very short sentence."
,
}
)
response
=
client
.
chat
.
completions
.
create
(
model
=
"default"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
content
},
],
temperature
=
0
,
)
assert
response
.
choices
[
0
].
message
.
role
==
"assistant"
text
=
response
.
choices
[
0
].
message
.
content
assert
isinstance
(
text
,
str
)
def
test_mixed_batch
(
self
):
image_ids
=
[
0
,
1
,
2
]
*
4
with
ThreadPoolExecutor
(
4
)
as
executor
:
list
(
executor
.
map
(
self
.
run_decode_with_image
,
image_ids
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment