sglang · Commit aa47f642 (Unverified)

Revert "[feat] Enable chunked prefill for llava-onevision" (#2329)

Authored Dec 02, 2024 by Ying Sheng; committed by GitHub on Dec 02, 2024.
Parent: 3ddb1c46

Showing 5 changed files with 18 additions and 221 deletions (+18, -221):
  python/sglang/srt/managers/schedule_batch.py      +0    -1
  python/sglang/srt/model_executor/model_runner.py  +4    -9
  python/sglang/srt/models/llava.py                 +14   -37
  test/srt/run_suite.py                             +0    -1
  test/srt/test_vision_chunked_prefill.py           +0    -173
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -128,7 +128,6 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
-    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
```
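The only change here is dropping the per-image pad-length bookkeeping. As a rough illustration of what the removed field tracked (numbers invented, not from the commit): `image_offsets` gives where each image's placeholder run starts in the token sequence, and `image_pad_len` gave that run's length, so the end of each image span was explicit.

```python
# Hypothetical values for illustration only.
image_offsets = [5, 1000]    # start index of each image's placeholder run
image_pad_len = [728, 728]   # length of each run (the field removed here)

# With pad lengths, each image's token span is known exactly:
spans = [(off, off + n) for off, n in zip(image_offsets, image_pad_len)]
print(spans)  # [(5, 733), (1000, 1728)]
```

Without it, the code falls back to treating an image's start offset as its only anchor, which suffices once prefill is never chunked mid-image.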
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -111,20 +111,15 @@ class ModelRunner:
         )

         if self.is_multimodal:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-                )
-                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True

         # Global vars
```
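The net effect: before the revert, only Mllama and Qwen2-VL had chunked prefill turned off; after it, every multimodal model gets `chunked_prefill_size = -1`, i.e. prefill in a single pass. A minimal sketch of the restored guard, where `ServerArgs` is a hypothetical stand-in rather than sglang's real class:

```python
# Minimal sketch of the restored behavior; ServerArgs is a stand-in.
class ServerArgs:
    chunked_prefill_size = 8192  # some default chunk size, in tokens
    disable_radix_cache = False

def apply_multimodal_defaults(args: ServerArgs, architectures: list) -> None:
    # After the revert: all multimodal models prefill in one shot.
    args.chunked_prefill_size = -1  # -1 disables chunked prefill
    if architectures == ["Qwen2VLForConditionalGeneration"]:
        args.disable_radix_cache = True  # qwen2-vl cannot use the radix cache

args = ServerArgs()
apply_multimodal_defaults(args, ["LlavaQwenForCausalLM"])
print(args.chunked_prefill_size, args.disable_radix_cache)  # -1 False
```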
python/sglang/srt/models/llava.py

```diff
@@ -57,7 +57,6 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
-        image_inputs.image_pad_len = []
         for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
@@ -104,7 +103,6 @@ class LlavaBaseForCausalLM(nn.Module):
                     + input_ids[offset + 1 :]
                 )
                 offset_list.append(offset)
-                image_inputs.image_pad_len.append(new_image_feature_len)

         image_inputs.image_offsets = offset_list
         return input_ids
```
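For context, `pad_input_ids` expands each image token into a run of hash-derived pad values so the radix cache can prefix-match on image content, and `offset_list` records where each run starts; the revert only stops recording the runs' lengths. A rough, hypothetical sketch of that expansion (all values invented for illustration):

```python
# Hypothetical illustration of the expansion done around these lines.
pad_value = 1234567           # derived from the image hash in the real code
new_image_feature_len = 4     # real values are hundreds of tokens
offset = 2                    # position of the image token in input_ids

input_ids = [101, 102, -200, 103]  # -200 standing in for the image token id
padded = (
    input_ids[:offset]
    + [pad_value] * new_image_feature_len
    + input_ids[offset + 1 :]
)
print(padded)  # [101, 102, 1234567, 1234567, 1234567, 1234567, 103]
```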
```diff
@@ -136,14 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -152,12 +142,18 @@ class LlavaBaseForCausalLM(nn.Module):
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
                 if im and im.image_offsets:
-                    max_image_offset.append(
-                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
-                    )
+                    max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)

+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
```
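Two things move here: the clamp-and-embed step returns to its pre-feature position, and the vision gate reverts to comparing against the start of each request's last image rather than its end (start plus pad length). A small sketch of the restored comparison, with invented numbers:

```python
import numpy as np

# Invented per-request values, for illustration only.
start_positions = np.array([0, 500])    # first position prefilled this step
max_image_offset = np.array([5, 100])   # start offset of each request's last image (-1 = none)

# A request still needs vision embeddings if its new tokens begin at or
# before its last image placeholder.
need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False]
```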
```diff
@@ -354,7 +350,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 # Fill in the placeholder for the image
                 extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                 prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                 pt = 0
                 for i in range(bs):
```
```diff
@@ -362,36 +357,18 @@ class LlavaBaseForCausalLM(nn.Module):
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    seq_len = extend_seq_lens[i]
                     prefix_len = prefix_lens_cpu[i]

                     # Multiple images
-                    for image_idx, image_offset in enumerate(
-                        image_inputs[i].image_offsets
-                    ):
-                        if (
-                            image_offset + image_inputs[i].image_pad_len[image_idx]
-                            <= prefix_len
-                        ):
+                    for j, image_offset in enumerate(image_inputs[i].image_offsets):
+                        if image_offset < prefix_len:
                             continue
-                        if image_offset >= prefix_len + seq_len:
-                            break

-                        tmp_image_feature = image_features[pt][image_idx]
+                        tmp_image_feature = image_features[pt][j]
                         pad_len = tmp_image_feature.shape[0]

-                        input_offset = image_offset - prefix_len
-                        left_idx = start_idx + input_offset
-                        right_idx = left_idx + pad_len
-                        assert right_idx > start_idx
-                        if input_offset < 0:
-                            left_idx = start_idx
-                            tmp_image_feature = tmp_image_feature[-input_offset:]
-                        if right_idx > start_idx + seq_len:
-                            tmp_image_feature = tmp_image_feature[
-                                : start_idx + seq_len - right_idx
-                            ]
-                            right_idx = start_idx + seq_len
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
                         try:
                             input_embeds[left_idx:right_idx] = tmp_image_feature
                         except RuntimeError as e:
```
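The deleted branch is what let an image span straddle chunk boundaries: it clipped each image's feature slice to the window [prefix_len, prefix_len + seq_len) covered by the current chunk. A standalone sketch of that arithmetic with invented numbers (the revert removes it, so each image must land entirely within one prefill pass):

```python
# Sketch of the removed chunk-window clipping; all numbers invented.
start_idx, prefix_len, seq_len = 0, 600, 512  # this chunk covers positions [600, 1112)
image_offset, pad_len = 500, 728              # image placeholders span [500, 1228)

input_offset = image_offset - prefix_len      # -100: image began in an earlier chunk
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
lo, hi = 0, pad_len                           # slice of the image features to copy

if input_offset < 0:                          # drop the part a previous chunk handled
    left_idx = start_idx
    lo = -input_offset
if right_idx > start_idx + seq_len:           # defer the part beyond this chunk
    hi += start_idx + seq_len - right_idx
    right_idx = start_idx + seq_len

print(left_idx, right_idx, hi - lo)  # 0 512 512: copy 512 of 728 feature rows now
```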
test/srt/run_suite.py

```diff
@@ -39,7 +39,6 @@ suites = {
         "test_triton_attention_kernels.py",
         "test_triton_attention_backend.py",
         "test_update_weights_from_disk.py",
-        "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
     ],
```
test/srt/test_vision_chunked_prefill.py (deleted, 100644 → 0)

```python
"""
Usage:
python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
"""

import base64
import io
import os
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import numpy as np
import requests
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestVisionChunkedPrefill(unittest.TestCase):
    def prepare_video_messages(self, video_path, max_frames_num=8):
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(
            0, total_frame_num - 1, max_frames_num, dtype=int
        )
        frame_idx = uniform_sampled_frames.tolist()
        frames = vr.get_batch(frame_idx).asnumpy()

        base64_frames = []
        for frame in frames:
            pil_img = Image.fromarray(frame)
            buff = io.BytesIO()
            pil_img.save(buff, format="JPEG")
            base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
            base64_frames.append(base64_str)

        messages = [{"role": "user", "content": []}]
        frame_format = {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,{}"},
            "modalities": "video",
        }

        for base64_frame in base64_frames:
            frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
                base64_frame
            )
            messages[0]["content"].append(frame_format.copy())

        prompt = {"type": "text", "text": "Please describe the video briefly."}
        messages[0]["content"].append(prompt)

        return messages

    def get_prompt_from_messages(self, messages):
        text = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
        )
        image_data = []
        for content in messages[0]["content"]:
            if content["type"] == "image_url":
                text += "<image>\n"
                image_data.append(content["image_url"]["url"])
        text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
        return text, image_data

    def generate(self, text, image_data):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "image_data": image_data,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "modalities": ["multi-images"],
            },
        ).json()
        return response["text"]

    def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
        # prepare the video input about Steven introducing ipod nano
        url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
        cache_dir = os.path.expanduser("~/.cache")
        file_path = os.path.join(cache_dir, "jobs.mp4")
        os.makedirs(cache_dir, exist_ok=True)
        if not os.path.exists(file_path):
            response = requests.get(url)
            response.raise_for_status()
            with open(file_path, "wb") as f:
                f.write(response.content)
        if not batch:
            assert isinstance(num_frame, int)
            messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
            text, image_data = self.get_prompt_from_messages(messages)
            return self.generate(text, image_data)
        else:
            assert isinstance(num_frame, list)
            func_args = []
            for max_frames_num in num_frame:
                messages = self.prepare_video_messages(
                    file_path,
                    max_frames_num=max_frames_num,
                )
                text, image_data = self.get_prompt_from_messages(messages)
                func_args.append((text, image_data))
            with ThreadPoolExecutor(max_workers=10) as executor:
                responses = list(executor.map(lambda p: self.generate(*p), func_args))
            return responses

    def run_generate(self, chunked_prefill_size, batch, num_frame):
        # launch server
        model = "lmms-lab/llava-onevision-qwen2-7b-ov"
        # model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
        self.base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                f"{chunked_prefill_size}",
            ],
        )
        try:
            return self.generate_for_video(batch, num_frame)
        finally:
            kill_process_tree(process.pid)

    def test_chunked_prefill(self):
        output_chunked = self.run_generate(
            chunked_prefill_size=1024, batch=False, num_frame=1
        )
        output_no_chunked = self.run_generate(
            chunked_prefill_size=-1, batch=False, num_frame=1
        )

        print("output with chunked prefill:")
        print(output_chunked)
        print("output without chunked prefill:")
        print(output_no_chunked)
        assert output_chunked == output_no_chunked

        output_chunked = self.run_generate(
            chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10]
        )
        output_no_chunked = self.run_generate(
            chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10]
        )

        print("output with chunked prefill:")
        print(output_chunked)
        print("output without chunked prefill:")
        print(output_no_chunked)
        assert output_chunked == output_no_chunked


if __name__ == "__main__":
    unittest.main()
```
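The test compared server outputs with chunked prefill on (1024) and off (-1), frame count by frame count. After this revert, ModelRunner forces `chunked_prefill_size = -1` for all multimodal models regardless of the flag, so the comparison is no longer exercisable; that is why the file is deleted along with its suite entry. For reference, a trimmed sketch of the launch/teardown pattern the test used, with values copied from the file above and the request loop elided:

```python
# Launch/teardown pattern from the deleted test (values copied from it).
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    "lmms-lab/llava-onevision-qwen2-7b-ov",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--chunked-prefill-size", "-1"],  # -1 disables chunked prefill
)
try:
    pass  # POST to DEFAULT_URL_FOR_TEST + "/generate" as in generate() above
finally:
    kill_process_tree(process.pid)
```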