Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b7038fec
Unverified
Commit
b7038fec
authored
Nov 28, 2024
by
Ying Sheng
Committed by
GitHub
Nov 28, 2024
Browse files
[fix] Fix prefix caching for multi-image/video (#2239)
parent
65fdb289
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
22 deletions
+26
-22
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+21
-15
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+2
-6
python/sglang/srt/models/qwen2_vl.py
python/sglang/srt/models/qwen2_vl.py
+1
-1
test/srt/test_session_control.py
test/srt/test_session_control.py
+2
-0
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
b7038fec
...
...
@@ -145,15 +145,17 @@ class ImageInputs:
# Use image hash as fake token_ids, which is then used for prefix matching
ret
=
ImageInputs
(
pixel_values
=
obj
[
"pixel_values"
],
image_hashes
=
hash
(
tuple
(
obj
[
"image_hashes"
]
))
,
image_hashes
=
obj
[
"image_hashes"
],
)
image_hash
=
ret
.
image_hashes
ret
.
pad_values
=
[
(
image_hash
)
%
vocab_size
,
(
image_hash
>>
16
)
%
vocab_size
,
(
image_hash
>>
32
)
%
vocab_size
,
(
image_hash
>>
64
)
%
vocab_size
,
]
if
not
isinstance
(
ret
.
image_hashes
,
list
):
ret
.
pad_values
=
[
(
ret
.
image_hashes
)
%
vocab_size
,
(
ret
.
image_hashes
>>
16
)
%
vocab_size
,
(
ret
.
image_hashes
>>
32
)
%
vocab_size
,
(
ret
.
image_hashes
>>
64
)
%
vocab_size
,
]
else
:
ret
.
pad_values
=
[
x
%
vocab_size
for
x
in
ret
.
image_hashes
]
optional_args
=
[
"image_sizes"
,
...
...
@@ -171,14 +173,18 @@ class ImageInputs:
def
merge
(
self
,
other
,
vocab_size
):
assert
self
.
pixel_values
.
shape
[
1
:]
==
other
.
pixel_values
.
shape
[
1
:]
self
.
pixel_values
=
np
.
concatenate
([
self
.
pixel_values
,
other
.
pixel_values
])
self
.
image_hashes
+=
other
.
image_hashes
self
.
pad_values
=
[
(
self
.
image_hashes
)
%
vocab_size
,
(
self
.
image_hashes
>>
16
)
%
vocab_size
,
(
self
.
image_hashes
>>
32
)
%
vocab_size
,
(
self
.
image_hashes
>>
64
)
%
vocab_size
,
]
if
isinstance
(
self
.
image_hashes
,
list
)
and
isinstance
(
other
.
image_hashes
,
list
):
self
.
image_hashes
+=
other
.
image_hashes
self
.
pad_values
=
[
x
%
vocab_size
for
x
in
self
.
image_hashes
]
else
:
self
.
image_hashes
=
hash
(
tuple
(
self
.
image_hashes
,
other
.
image_hashes
))
self
.
pad_values
=
[
(
self
.
image_hashes
)
%
vocab_size
,
(
self
.
image_hashes
>>
16
)
%
vocab_size
,
(
self
.
image_hashes
>>
32
)
%
vocab_size
,
(
self
.
image_hashes
>>
64
)
%
vocab_size
,
]
optional_args
=
[
"image_sizes"
,
...
...
python/sglang/srt/models/llava.py
View file @
b7038fec
...
...
@@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
else
:
image_aspect_ratio
=
"anyres"
offset_list
=
[]
for
image_
s
in
image_sizes
:
for
image_
idx
,
image_s
in
enumerate
(
image_sizes
)
:
if
len
(
image_sizes
)
>
16
:
# 2x2 pooling with stride 2
new_image_feature_len
=
(
...
...
@@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
new_w
=
int
(
new_w
//
times
)
new_image_feature_len
+=
new_h
*
(
new_w
+
1
)
pad_ids
=
pad_values
*
(
(
new_image_feature_len
+
len
(
pad_values
))
//
len
(
pad_values
)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try
:
offset
=
input_ids
.
index
(
self
.
config
.
image_token_index
)
except
ValueError
:
...
...
@@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids
=
(
input_ids
[:
offset
]
+
pad_
ids
[:
new_image_feature_len
]
+
[
pad_
values
[
image_idx
]]
*
new_image_feature_len
+
input_ids
[
offset
+
1
:]
)
offset_list
.
append
(
offset
)
...
...
python/sglang/srt/models/qwen2_vl.py
View file @
b7038fec
...
...
@@ -500,7 +500,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
return
num_image_tokens
# Use grid_t * grid_w * grid_h to pad tokens for each image
# a
n
d replaced padding by unique image hash
# a
d
d replaced padding by unique image hash
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
ImageInputs
):
image_grid_thws
=
image_inputs
.
image_grid_thws
pad_values
=
image_inputs
.
pad_values
...
...
test/srt/test_session_control.py
View file @
b7038fec
...
...
@@ -301,6 +301,8 @@ class TestSessionControlVision(unittest.TestCase):
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
# 2. not use session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
input_ids_first_req
=
None
input_ids
=
[]
outputs_normal
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment