Commit 480e38a7: [feat] Enable chunked prefill for llava-onevision (#2281)
Authored Dec 02, 2024 by Ying Sheng; committed via GitHub on Dec 02, 2024 (unverified).
Parent: 69e2d4fb
Showing 5 changed files with 221 additions and 18 deletions.
- python/sglang/srt/managers/schedule_batch.py (+1, -0)
- python/sglang/srt/model_executor/model_runner.py (+9, -4)
- python/sglang/srt/models/llava.py (+37, -14)
- test/srt/run_suite.py (+1, -0)
- test/srt/test_vision_chunked_prefill.py (+173, -0)
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -128,6 +128,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
```
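The new `image_pad_len` field records how many placeholder tokens each image occupies in the prompt. Together with `image_offsets`, it describes the exact token span of every image, which the chunked-prefill path later uses to decide whether an image overlaps the current chunk. The sketch below is illustrative only; the helper names `image_span` and `overlaps_chunk` are hypothetical and not part of the commit.

```python
# Illustrative sketch, not part of the commit: how image_offsets and the new
# image_pad_len together describe the token span each image occupies, and how
# that span can be tested against a prefill chunk. Helper names are hypothetical.

def image_span(image_offsets, image_pad_len, idx):
    """Half-open token range [start, end) covered by image `idx`."""
    start = image_offsets[idx]
    return start, start + image_pad_len[idx]

def overlaps_chunk(span, chunk_start, chunk_end):
    """True if the image's token span intersects the chunk [chunk_start, chunk_end)."""
    start, end = span
    return start < chunk_end and end > chunk_start

# An image whose placeholder tokens sit at positions [100, 828) is needed both
# by a chunk covering [0, 512) and by the next chunk covering [512, 1024).
span = image_span([100], [728], 0)
print(overlaps_chunk(span, 0, 512))      # True
print(overlaps_chunk(span, 512, 1024))   # True
print(overlaps_chunk(span, 1024, 1536))  # False
```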
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -110,15 +110,20 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True

         # Global vars
```
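With this change, chunked prefill stays enabled for multimodal models in general (only the static memory fraction is reduced); it is force-disabled only for mllama and qwen2-vl. A `--chunked-prefill-size` of -1 means the whole prompt is prefilled in one pass. Below is a minimal sketch of how a chunk size splits a long prefill, assuming a simple greedy split; this is not the scheduler's actual implementation.

```python
# Minimal sketch, assuming a simple greedy split; not the scheduler's real code.
# chunked_prefill_size == -1 disables chunking entirely.

def prefill_chunks(prompt_len: int, chunked_prefill_size: int):
    """Return (start, end) token ranges covered by each prefill pass."""
    if chunked_prefill_size == -1:
        return [(0, prompt_len)]
    chunks, start = [], 0
    while start < prompt_len:
        end = min(start + chunked_prefill_size, prompt_len)
        chunks.append((start, end))
        start = end
    return chunks

print(prefill_chunks(2500, 1024))  # [(0, 1024), (1024, 2048), (2048, 2500)]
print(prefill_chunks(2500, -1))    # [(0, 2500)]
```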
python/sglang/srt/models/llava.py

```diff
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
+        image_inputs.image_pad_len = []
         for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
+            image_inputs.image_pad_len.append(new_image_feature_len)

         image_inputs.image_offsets = offset_list
         return input_ids
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
                 if im and im.image_offsets:
-                    max_image_offset.append(max(im.image_offsets))
+                    max_image_offset.append(
+                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
+                    )
                 else:
                     max_image_offset.append(-1)

-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):
            # Fill in the placeholder for the image
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+           extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
            pt = 0
            for i in range(bs):
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
                    continue

                start_idx = extend_start_loc_cpu[i]
+               seq_len = extend_seq_lens[i]
                prefix_len = prefix_lens_cpu[i]

                # Multiple images
-               for j, image_offset in enumerate(image_inputs[i].image_offsets):
-                   if image_offset < prefix_len:
+               for image_idx, image_offset in enumerate(
+                   image_inputs[i].image_offsets
+               ):
+                   if (
+                       image_offset + image_inputs[i].image_pad_len[image_idx]
+                       <= prefix_len
+                   ):
                        continue
+                   if image_offset >= prefix_len + seq_len:
+                       break

-                   tmp_image_feature = image_features[pt][j]
+                   tmp_image_feature = image_features[pt][image_idx]
                    pad_len = tmp_image_feature.shape[0]

-                   left_idx = start_idx + (image_offset - prefix_len)
-                   right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                   input_offset = image_offset - prefix_len
+                   left_idx = start_idx + input_offset
+                   right_idx = left_idx + pad_len
+                   assert right_idx > start_idx
+                   if input_offset < 0:
+                       left_idx = start_idx
+                       tmp_image_feature = tmp_image_feature[-input_offset:]
+                   if right_idx > start_idx + seq_len:
+                       tmp_image_feature = tmp_image_feature[
+                           : start_idx + seq_len - right_idx
+                       ]
+                       right_idx = start_idx + seq_len
                    try:
                        input_embeds[left_idx:right_idx] = tmp_image_feature
                    except RuntimeError as e:
```
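The core of the change is the clipping logic above: with chunked prefill, one forward pass only covers the token window [prefix_len, prefix_len + seq_len), so an image's precomputed embedding may fall entirely before the window (skip it), entirely after it (stop early), or straddle a boundary, in which case only the overlapping rows are copied. Below is a self-contained sketch of that trimming; the function name and signature are hypothetical, and only the index arithmetic mirrors the diff.

```python
# Self-contained sketch of the clipping step above, with simplified inputs.
# The function name and signature are hypothetical; only the index arithmetic
# mirrors the commit.
import torch

def fill_image_embedding(input_embeds, image_feature, image_offset,
                         prefix_len, seq_len, start_idx=0):
    """Copy the rows of `image_feature` that land inside the prefill chunk
    covering prompt positions [prefix_len, prefix_len + seq_len)."""
    pad_len = image_feature.shape[0]
    input_offset = image_offset - prefix_len
    left_idx = start_idx + input_offset
    right_idx = left_idx + pad_len
    if input_offset < 0:
        # The image started in an earlier chunk: drop the rows already consumed.
        left_idx = start_idx
        image_feature = image_feature[-input_offset:]
    if right_idx > start_idx + seq_len:
        # The image continues into the next chunk: keep only what fits here.
        image_feature = image_feature[: start_idx + seq_len - right_idx]
        right_idx = start_idx + seq_len
    input_embeds[left_idx:right_idx] = image_feature
    return input_embeds

# An image with 8 embedding rows starts at prompt position 6; the current chunk
# covers positions [4, 12), so only the first 6 rows are written in this pass.
embeds = torch.zeros(8, 4)
features = torch.ones(8, 4)
fill_image_embedding(embeds, features, image_offset=6, prefix_len=4, seq_len=8)
print(embeds[:, 0])  # tensor([0., 0., 1., 1., 1., 1., 1., 1.])
```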
test/srt/run_suite.py

```diff
@@ -39,6 +39,7 @@ suites = {
         "test_triton_attention_kernels.py",
         "test_triton_attention_backend.py",
         "test_update_weights_from_disk.py",
+        "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
     ],
```
test/srt/test_vision_chunked_prefill.py (new file, mode 100644)

```python
"""
Usage:
python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
"""

import base64
import io
import os
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import numpy as np
import requests
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestVisionChunkedPrefill(unittest.TestCase):
    def prepare_video_messages(self, video_path, max_frames_num=8):
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(
            0, total_frame_num - 1, max_frames_num, dtype=int
        )
        frame_idx = uniform_sampled_frames.tolist()
        frames = vr.get_batch(frame_idx).asnumpy()

        base64_frames = []
        for frame in frames:
            pil_img = Image.fromarray(frame)
            buff = io.BytesIO()
            pil_img.save(buff, format="JPEG")
            base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
            base64_frames.append(base64_str)

        messages = [{"role": "user", "content": []}]
        frame_format = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,{}"},
            "modalities": "video",
        }

        for base64_frame in base64_frames:
            frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
                base64_frame
            )
            messages[0]["content"].append(frame_format.copy())

        prompt = {"type": "text", "text": "Please describe the video briefly."}
        messages[0]["content"].append(prompt)
        return messages

    def get_prompt_from_messages(self, messages):
        text = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
        )
        image_data = []
        for content in messages[0]["content"]:
            if content["type"] == "image_url":
                text += "<image>\n"
                image_data.append(content["image_url"]["url"])
        text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
        return text, image_data

    def generate(self, text, image_data):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "image_data": image_data,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "modalities": ["multi-images"],
            },
        ).json()
        return response["text"]

    def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
        # prepare the video input about Steven introducing ipod nano
        url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
        cache_dir = os.path.expanduser("~/.cache")
        file_path = os.path.join(cache_dir, "jobs.mp4")
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(file_path):
            response = requests.get(url)
            response.raise_for_status()
            with open(file_path, "wb") as f:
                f.write(response.content)

        if not batch:
            assert isinstance(num_frame, int)
            messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
            text, image_data = self.get_prompt_from_messages(messages)
            return self.generate(text, image_data)
        else:
            assert isinstance(num_frame, list)
            func_args = []
            for max_frames_num in num_frame:
                messages = self.prepare_video_messages(
                    file_path,
                    max_frames_num=max_frames_num,
                )
                text, image_data = self.get_prompt_from_messages(messages)
                func_args.append((text, image_data))

            with ThreadPoolExecutor(max_workers=10) as executor:
                responses = list(executor.map(lambda p: self.generate(*p), func_args))
            return responses

    def run_generate(self, chunked_prefill_size, batch, num_frame):
        # launch server
        model = "lmms-lab/llava-onevision-qwen2-7b-ov"
        # model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
        self.base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                f"{chunked_prefill_size}",
            ],
        )
        try:
            return self.generate_for_video(batch, num_frame)
        finally:
            kill_process_tree(process.pid)

    def test_chunked_prefill(self):
        output_chunked = self.run_generate(
            chunked_prefill_size=1024, batch=False, num_frame=1
        )
        output_no_chunked = self.run_generate(
            chunked_prefill_size=-1, batch=False, num_frame=1
        )

        print("output with chunked prefill:")
        print(output_chunked)
        print("output without chunked prefill:")
        print(output_no_chunked)
        assert output_chunked == output_no_chunked

        output_chunked = self.run_generate(
            chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10]
        )
        output_no_chunked = self.run_generate(
            chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10]
        )

        print("output with chunked prefill:")
        print(output_chunked)
        print("output without chunked prefill:")
        print(output_no_chunked)
        assert output_chunked == output_no_chunked


if __name__ == "__main__":
    unittest.main()
```
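The test runs each prompt at temperature 0 so the chunked and non-chunked outputs should match exactly. Besides the command in the module docstring, the same test can be invoked programmatically; a sketch assuming the working directory is test/srt and that decord, Pillow, numpy, and a GPU-capable sglang install are available:

```python
# Programmatic equivalent of the usage line in the module docstring
# (assumption: run from test/srt so the module is importable).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```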