Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
37c8a576
Unverified
Commit
37c8a576
authored
Nov 27, 2024
by
Ying Sheng
Committed by
GitHub
Nov 27, 2024
Browse files
[feat] Support session control for vision language models (#2210)
parent
c754652f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
265 additions
and
21 deletions
+265
-21
python/sglang/srt/managers/image_processor.py
python/sglang/srt/managers/image_processor.py
+6
-9
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+37
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-2
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+15
-4
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+7
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_session_control.py
test/srt/test_session_control.py
+196
-4
No files found.
python/sglang/srt/managers/image_processor.py
View file @
37c8a576
...
...
@@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor):
if
not
image_data
:
return
None
modalities
=
request_obj
.
modalities
or
[
"image"
]
aspect_ratio
=
getattr
(
self
.
hf_config
,
"image_aspect_ratio"
,
None
)
grid_pinpoints
=
(
self
.
hf_config
.
image_grid_pinpoints
...
...
@@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor):
else
None
)
if
isinstance
(
image_data
,
str
):
image_data
=
[
image_data
]
if
isinstance
(
image_data
,
list
)
and
len
(
image_data
)
>
0
:
# Multiple imag
es
if
len
(
image_data
)
>
1
:
if
"multi-images"
in
modalities
or
"video"
in
modaliti
es
:
# Multiple images
aspect_ratio
=
"pad"
# LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values
,
image_hashes
,
image_sizes
=
[],
[],
[]
res
=
[]
...
...
@@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor):
)
image_hashes
=
[
image_hash
]
image_sizes
=
[
image_size
]
elif
isinstance
(
image_data
,
str
):
# A single image
pixel_values
,
image_hash
,
image_size
=
await
self
.
_process_single_image
(
image_data
,
aspect_ratio
,
grid_pinpoints
)
image_hashes
=
[
image_hash
]
image_sizes
=
[
image_size
]
else
:
raise
ValueError
(
f
"Invalid image data:
{
image_data
}
"
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
37c8a576
...
...
@@ -31,6 +31,7 @@ import dataclasses
import
logging
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
triton
import
triton.language
as
tl
...
...
@@ -167,6 +168,30 @@ class ImageInputs:
return
ret
def
merge
(
self
,
other
,
vocab_size
):
assert
self
.
pixel_values
.
shape
[
1
:]
==
other
.
pixel_values
.
shape
[
1
:]
self
.
pixel_values
=
np
.
concatenate
([
self
.
pixel_values
,
other
.
pixel_values
])
self
.
image_hashes
+=
other
.
image_hashes
self
.
pad_values
=
[
(
self
.
image_hashes
)
%
vocab_size
,
(
self
.
image_hashes
>>
16
)
%
vocab_size
,
(
self
.
image_hashes
>>
32
)
%
vocab_size
,
(
self
.
image_hashes
>>
64
)
%
vocab_size
,
]
optional_args
=
[
"image_sizes"
,
"image_offsets"
,
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids"
,
"aspect_ratio_mask"
,
"image_grid_thws"
,
]
for
arg
in
optional_args
:
if
getattr
(
self
,
arg
,
None
)
is
not
None
:
setattr
(
self
,
arg
,
getattr
(
self
,
arg
)
+
getattr
(
other
,
arg
))
class
Req
:
"""The input and output status of a request."""
...
...
@@ -177,6 +202,7 @@ class Req:
origin_input_text
:
str
,
origin_input_ids
:
Tuple
[
int
],
sampling_params
:
SamplingParams
,
origin_input_ids_unpadded
:
Optional
[
Tuple
[
int
]]
=
None
,
lora_path
:
Optional
[
str
]
=
None
,
input_embeds
:
Optional
[
List
[
List
[
float
]]]
=
None
,
session_id
:
Optional
[
str
]
=
None
,
...
...
@@ -184,7 +210,11 @@ class Req:
# Input and output info
self
.
rid
=
rid
self
.
origin_input_text
=
origin_input_text
self
.
origin_input_ids_unpadded
=
origin_input_ids
# Before image padding
self
.
origin_input_ids_unpadded
=
(
origin_input_ids_unpadded
if
origin_input_ids_unpadded
else
origin_input_ids
# Before image padding
)
self
.
origin_input_ids
=
origin_input_ids
self
.
output_ids
=
[]
# Each decode stage's output ids
self
.
fill_ids
=
None
# fill_ids = origin_input_ids + output_ids
...
...
@@ -260,6 +290,12 @@ class Req:
# The number of cached tokens, that were already cached in the KV cache
self
.
cached_tokens
=
0
def
extend_image_inputs
(
self
,
image_inputs
,
vocab_size
):
if
self
.
image_inputs
is
None
:
self
.
image_inputs
=
image_inputs
else
:
self
.
image_inputs
.
merge
(
image_inputs
,
vocab_size
)
# whether request reached finished condition
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
...
...
python/sglang/srt/managers/scheduler.py
View file @
37c8a576
...
...
@@ -559,12 +559,13 @@ class Scheduler:
# Image inputs
if
recv_req
.
image_inputs
is
not
None
:
req
.
image_inputs
=
ImageInputs
.
from_dict
(
image_inputs
=
ImageInputs
.
from_dict
(
recv_req
.
image_inputs
,
self
.
model_config
.
vocab_size
)
req
.
origin_input_ids
=
self
.
pad_input_ids_func
(
req
.
origin_input_ids
_unpadded
,
req
.
image_inputs
req
.
origin_input_ids
,
image_inputs
)
req
.
extend_image_inputs
(
image_inputs
,
self
.
model_config
.
vocab_size
)
if
len
(
req
.
origin_input_ids
)
>
self
.
max_req_input_len
:
req
.
finished_reason
=
FINISH_ABORT
(
...
...
python/sglang/srt/managers/session_controller.py
View file @
37c8a576
...
...
@@ -41,16 +41,27 @@ class Session:
]
+
req
.
input_ids
)
input_ids_unpadded
=
(
self
.
reqs
[
-
1
].
origin_input_ids_unpadded
+
self
.
reqs
[
-
1
].
output_ids
[
:
self
.
reqs
[
-
1
].
sampling_params
.
max_new_tokens
]
+
req
.
input_ids
)
else
:
input_ids
=
req
.
input_ids
input_ids_unpadded
=
req
.
input_ids
new_req
=
Req
(
req
.
rid
,
None
,
input_ids
,
req
.
sampling_params
,
rid
=
req
.
rid
,
origin_input_text
=
None
,
origin_input_ids
=
input_ids
,
origin_input_ids_unpadded
=
input_ids_unpadded
,
sampling_params
=
req
.
sampling_params
,
lora_path
=
req
.
lora_path
,
session_id
=
self
.
session_id
,
)
if
len
(
self
.
reqs
)
>
0
:
new_req
.
image_inputs
=
self
.
reqs
[
-
1
].
image_inputs
new_req
.
tokenizer
=
tokenizer
if
req
.
session_rid
is
not
None
and
len
(
self
.
reqs
)
==
0
:
new_req
.
finished_reason
=
FINISH_ABORT
(
...
...
python/sglang/srt/models/llava.py
View file @
37c8a576
...
...
@@ -49,7 +49,13 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes
,
pad_values
=
image_inputs
.
image_sizes
,
image_inputs
.
pad_values
# hardcode for spatial_unpad + anyres
image_aspect_ratio
=
"anyres"
if
len
(
image_sizes
)
==
1
else
"pad"
if
image_inputs
.
modalities
is
not
None
and
(
"multi-images"
in
image_inputs
.
modalities
or
"video"
in
image_inputs
.
modalities
):
image_aspect_ratio
=
"pad"
else
:
image_aspect_ratio
=
"anyres"
offset_list
=
[]
for
image_s
in
image_sizes
:
if
len
(
image_sizes
)
>
16
:
...
...
test/srt/run_suite.py
View file @
37c8a576
...
...
@@ -36,6 +36,7 @@ suites = {
"test_triton_attention_backend.py"
,
"test_update_weights.py"
,
"test_vision_openai_server.py"
,
"test_session_control.py"
,
],
"sampling/penaltylib"
:
glob
.
glob
(
"sampling/penaltylib/**/test_*.py"
,
recursive
=
True
...
...
test/srt/test_session_control.py
View file @
37c8a576
"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control
_vlm
python3 -m unittest test_session_control.TestSessionControl
Vision
.test_session_control
"""
import
unittest
...
...
@@ -61,6 +61,8 @@ class TestSessionControl(unittest.TestCase):
"max_new_tokens"
:
(
16
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
...
...
@@ -79,6 +81,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
...
...
@@ -93,6 +97,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
...
...
@@ -113,6 +119,8 @@ class TestSessionControl(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
...
...
@@ -133,13 +141,16 @@ class TestSessionControl(unittest.TestCase):
"max_new_tokens"
:
(
16
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
if
i
>
0
:
input_ids
+=
tokenizer
.
encode
(
response
[
"text"
])[
1
:
]
# drop the bos token
output_ids
=
tokenizer
.
encode
(
response
[
"text"
])
if
output_ids
[
0
]
==
tokenizer
.
bos_token_id
:
output_ids
=
output_ids
[
1
:]
input_ids
+=
output_ids
outputs_normal
.
append
(
response
[
"text"
])
if
i
==
0
:
input_ids_first_req
=
input_ids
.
copy
()
...
...
@@ -152,6 +163,187 @@ class TestSessionControl(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
outputs_normal
.
append
(
response
[
"text"
])
print
(
"outputs from chunked queries with session control:"
)
print
(
outputs_from_session
)
print
(
"outputs from normal queries:"
)
print
(
outputs_normal
)
assert
outputs_from_session
==
outputs_normal
class
TestSessionControlVision
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
"lmms-lab/llava-onevision-qwen2-7b-ov"
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
# other_args={"--disable-radix"},
)
@
classmethod
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_session_control
(
self
):
text_chunks
=
[
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
,
"<|im_start|>user
\n
<image>
\n
Describe this image in a very short sentence.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with the previous image? Answer yes or no.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with the previous image? Answer yes or no.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
]
image_chunks
=
[
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
,
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
,
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
,
]
assert
len
(
text_chunks
)
==
len
(
image_chunks
)
+
1
tokenizer
=
get_tokenizer
(
self
.
model
)
text_input_ids
=
[
tokenizer
.
encode
(
x
)
for
x
in
text_chunks
]
# 1. using session control
session_id
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
},
).
json
()
rid
=
None
first_rid
=
None
outputs_from_session
=
[]
for
i
in
range
(
len
(
text_input_ids
)):
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
i
],
"image_data"
:
image_chunks
[
i
-
1
]
if
i
>
0
else
None
,
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
rid
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
rid
=
response
[
"meta_info"
][
"id"
]
if
i
==
0
:
first_rid
=
rid
if
i
>
0
:
outputs_from_session
.
append
(
response
[
"text"
])
# backtrack to the first request and regenerate
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
first_rid
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
outputs_from_session
.
append
(
response
[
"text"
])
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
rid
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
ret
=
requests
.
post
(
self
.
base_url
+
"/close_session"
,
json
=
{
"session_id"
:
session_id
},
)
assert
ret
.
status_code
==
200
# send a request to a closed session, should see abort
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"session"
:
[
session_id
,
first_rid
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
# 2. not use session control
input_ids_first_req
=
None
input_ids
=
[]
outputs_normal
=
[]
for
i
in
range
(
len
(
text_input_ids
)):
input_ids
+=
text_input_ids
[
i
]
image_data
=
image_chunks
[:
i
]
if
i
>
0
else
None
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
input_ids
,
"image_data"
:
image_data
,
"modalities"
:
[
"multi-images"
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
if
i
>
0
:
output_ids
=
tokenizer
.
encode
(
response
[
"text"
])
if
output_ids
[
0
]
==
tokenizer
.
bos_token_id
:
output_ids
=
output_ids
[
1
:]
input_ids
+=
output_ids
outputs_normal
.
append
(
response
[
"text"
])
if
i
==
0
:
input_ids_first_req
=
input_ids
.
copy
()
input_ids_first_req
+=
text_input_ids
[
-
1
]
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
input_ids_first_req
,
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
},
).
json
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment