sglang · Commit f50a6cf4

Unverified commit, authored Nov 29, 2024 by Lianmin Zheng; committed by GitHub, Nov 29, 2024.
Parent: fe97a2d4

Fix hash collision for multi modal models (#2256)

Showing 6 changed files with 42 additions and 39 deletions (+42, -39):
- python/sglang/srt/managers/schedule_batch.py (+18, -27)
- python/sglang/srt/managers/scheduler.py (+9, -7)
- python/sglang/srt/managers/session_controller.py (+0, -3)
- python/sglang/srt/managers/tokenizer_manager.py (+1, -0)
- python/sglang/srt/models/llava.py (+5, -0)
- python/sglang/srt/models/qwen2_vl.py (+9, -2)
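The substance of the fix: the old code folded each image hash into the token-id range with "% vocab_size", keeping only about log2(vocab_size) bits (roughly 15 for a 32k vocabulary), so two different images could produce identical fake token sequences and the radix cache could silently match the wrong prefix. The new code keeps 30 bits per hash as the pad value and defers the vocab-range clamp to the model forward. A small illustrative sketch follows (not sglang code; the vocabulary size and hashes are assumptions for illustration):

# Illustrative sketch, not sglang code: why reducing an image hash modulo the
# vocab size collides much sooner than keeping 30 bits of it.
vocab_size = 32_000  # assumed vocab size for illustration

def old_pad_value(image_hash: int) -> int:
    # Old scheme: fold the hash into [0, vocab_size). Only ~15 bits survive,
    # so by the birthday bound a few hundred distinct images already collide.
    return image_hash % vocab_size

def new_pad_value(image_hash: int) -> int:
    # New scheme: keep 30 bits. The resulting fake ids can exceed vocab_size,
    # which is why the model forward must clamp input_ids before the embedding
    # lookup (see the llava.py and qwen2_vl.py hunks below).
    return image_hash % (1 << 30)

h_a, h_b = hash("image_a"), hash("image_b")
print(old_pad_value(h_a), old_pad_value(h_b))  # drawn from a ~2**15 space
print(new_pad_value(h_a), new_pad_value(h_b))  # drawn from a 2**30 space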
python/sglang/srt/managers/schedule_batch.py

@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
 class ImageInputs:
     """The image related inputs."""
 
-    pixel_values: torch.Tensor
+    pixel_values: Union[torch.Tensor, np.array]
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
@@ -132,7 +132,7 @@ class ImageInputs:
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    image_embeds: Optional[List[torch.Tensor]] = None
+    # Llava related
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
@@ -141,21 +141,17 @@ class ImageInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
-    def from_dict(obj, vocab_size):
-        # Use image hash as fake token_ids, which is then used for prefix matching
+    def from_dict(obj: dict):
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
             image_hashes=obj["image_hashes"],
         )
-        if not isinstance(ret.image_hashes, list):
-            ret.pad_values = [
-                (ret.image_hashes) % vocab_size,
-                (ret.image_hashes >> 16) % vocab_size,
-                (ret.image_hashes >> 32) % vocab_size,
-                (ret.image_hashes >> 64) % vocab_size,
-            ]
-        else:
-            ret.pad_values = [x % vocab_size for x in ret.image_hashes]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
+        # cuda memory access.
+        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
 
         optional_args = [
             "image_sizes",
@@ -170,21 +166,16 @@ class ImageInputs:
         return ret
 
-    def merge(self, other, vocab_size):
+    def merge(self, other):
         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-        if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
-            self.image_hashes += other.image_hashes
-            self.pad_values = [x % vocab_size for x in self.image_hashes]
-        else:
-            self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
-            self.pad_values = [
-                (self.image_hashes) % vocab_size,
-                (self.image_hashes >> 16) % vocab_size,
-                (self.image_hashes >> 32) % vocab_size,
-                (self.image_hashes >> 64) % vocab_size,
-            ]
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid illegal
+        # cuda memory access.
+        self.image_hashes += other.image_hashes
+        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
 
         optional_args = [
             "image_sizes",
@@ -297,11 +288,11 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-    def extend_image_inputs(self, image_inputs, vocab_size):
+    def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
         else:
-            self.image_inputs.merge(image_inputs, vocab_size)
+            self.image_inputs.merge(image_inputs)
 
     # whether request reached finished condition
     def finished(self) -> bool:
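A condensed, standalone sketch of the new ImageInputs lifecycle as the diff describes it: from_dict now always treats image_hashes as a list, and merge recomputes pad values over the combined hashes. Field and method names mirror the diff, but this toy class is not the sglang implementation:

from dataclasses import dataclass
from typing import List, Optional

import numpy as np

@dataclass
class ImageInputsSketch:
    pixel_values: np.ndarray
    image_hashes: List[int]
    pad_values: Optional[List[int]] = None

    @staticmethod
    def from_dict(obj: dict) -> "ImageInputsSketch":
        ret = ImageInputsSketch(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )
        # 30-bit fake token ids, used as radix-cache prefix keys.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
        return ret

    def merge(self, other: "ImageInputsSketch") -> None:
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
        # Recompute pad values over the combined hash list, mirroring the diff.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

a = ImageInputsSketch.from_dict(
    {"pixel_values": np.zeros((1, 3, 8, 8)), "image_hashes": [hash("img_a")]}
)
b = ImageInputsSketch.from_dict(
    {"pixel_values": np.ones((1, 3, 8, 8)), "image_hashes": [hash("img_b")]}
)
a.merge(b)
print(a.pixel_values.shape, a.pad_values)  # (2, 3, 8, 8) and two 30-bit values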
python/sglang/srt/managers/scheduler.py

@@ -526,8 +526,9 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        # Create a new request
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
-            # Create a new request
+
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
                 seq_length = len(recv_req.input_embeds)
@@ -558,20 +559,20 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Image inputs
+        # Handle image inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(
-                recv_req.image_inputs, self.model_config.vocab_size
-            )
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
             )
-            req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
+            req.extend_image_inputs(image_inputs)
 
             if len(req.origin_input_ids) > self.max_req_input_len:
                 req.finished_reason = FINISH_ABORT(
                     "Image request length is longer than the KV cache pool size or "
-                    "the max context length aborting because you cannot truncate the image embeds"
+                    "the max context length. "
+                    "Abort this request because you cannot truncate the image embeds"
                 )
                 req.image_inputs = None
                 req.origin_input_ids = [0]
@@ -579,6 +580,7 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
+        # Copy more attributes
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
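For context, this is the scheduler path that consumes the pad values: pad_input_ids_func expands each image placeholder into per-image dummy tokens, and the request is aborted rather than truncated if the expanded prompt no longer fits. The sketch below uses hypothetical names and sizes (IMAGE_PLACEHOLDER, TOKENS_PER_IMAGE, and the length limit are stand-ins, not sglang's real signatures):

from typing import List

MAX_REQ_INPUT_LEN = 4096  # illustrative limit, not the real server config
IMAGE_PLACEHOLDER = -1    # stand-in for the model's image token id
TOKENS_PER_IMAGE = 576    # e.g. a 24x24 patch grid; model dependent

def pad_input_ids(input_ids: List[int], pad_values: List[int]) -> List[int]:
    # Replace each image placeholder with TOKENS_PER_IMAGE copies of that
    # image's pad value, giving the radix cache a per-image unique prefix.
    out, img_idx = [], 0
    for tok in input_ids:
        if tok == IMAGE_PLACEHOLDER:
            out.extend([pad_values[img_idx]] * TOKENS_PER_IMAGE)
            img_idx += 1
        else:
            out.append(tok)
    return out

ids = pad_input_ids([101, IMAGE_PLACEHOLDER, 2023], [hash("img") % (1 << 30)])
if len(ids) > MAX_REQ_INPUT_LEN:
    print("abort: cannot truncate the image embeds")  # mirrors FINISH_ABORT
else:
    print(f"padded length: {len(ids)}")  # 1 + 576 + 1 = 578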
python/sglang/srt/managers/session_controller.py

@@ -10,10 +10,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import copy
 import uuid
-from dataclasses import dataclass
-from typing import Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
python/sglang/srt/managers/tokenizer_manager.py

@@ -216,6 +216,7 @@ class TokenizerManager:
         input_ids = obj.input_ids
 
         if self.is_generation:
+            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
python/sglang/srt/models/llava.py

@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
             else:
                 max_image_offset.append(-1)
 
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # These values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
         # Embed text inputs
         input_embeds = self.language_model.model.embed_tokens(input_ids)
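The clamp exists because 30-bit pad values are valid cache keys but out-of-range embedding indices; embedding them unclamped would index past the table (the "illegal cuda memory access" the diff's comments warn about). A minimal torch demonstration with an assumed 32k-entry table:

import torch

vocab_size = 32_000  # illustrative table size
embed_tokens = torch.nn.Embedding(vocab_size, 16)

# One text token, one 30-bit fake image token, one more text token.
input_ids = torch.tensor([101, hash("img") % (1 << 30), 2023])
input_ids.clamp_(min=0, max=vocab_size - 1)  # same in-place call as the diff
embeds = embed_tokens(input_ids)             # safe: every id is now < vocab_size
print(embeds.shape)  # torch.Size([3, 16])

The clamped image positions embed to arbitrary in-range vectors, which is harmless: the model overwrites those rows with the vision encoder's output before the language model runs.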
python/sglang/srt/models/qwen2_vl.py

@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
             `None` if no images are passed.
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
-            positions = forward_batch.mrope_positions
         image_inputs = None
         if forward_batch.image_inputs is not None:
             image_inputs = [
                 img for img in forward_batch.image_inputs if img is not None
             ]
+
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
+
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # These values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             inputs_embeds = self.model.embed_tokens(input_ids)
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
             prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
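This hunk also moves the mrope probe so it runs after the image-input filtering instead of interleaved with it. A hedged sketch of that probe in isolation (the rope_scaling shape mirrors Hugging Face Qwen2-VL configs; the position values here are just labels):

class FakeConfig:
    # Qwen2-VL-style rope scaling config (illustrative).
    rope_scaling = {"type": "mrope"}

def choose_positions(config, positions, mrope_positions):
    # Mirrors the diff's check: only mrope models swap in the 3D positions.
    if getattr(config, "rope_scaling", {}).get("type", None) == "mrope":
        return mrope_positions
    return positions

print(choose_positions(FakeConfig(), "1d positions", "(3, seq_len) mrope positions"))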