Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a8aad935
Unverified
Commit
a8aad935
authored
Nov 11, 2024
by
yizhang2077
Committed by
GitHub
Nov 10, 2024
Browse files
qwen2vl fix bug for #1971 #1897 (#1984)
parent
47ffe7af
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
14 deletions
+8
-14
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-9
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+7
-3
python/sglang/srt/models/qwen2_vl.py
python/sglang/srt/models/qwen2_vl.py
+0
-2
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
a8aad935
...
...
@@ -133,6 +133,7 @@ class ImageInputs:
aspect_ratio_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# Qwen2-VL related
image_grid_thws
:
List
[
Tuple
[
int
,
int
,
int
]]
=
None
mrope_position_delta
:
Optional
[
torch
.
Tensor
]
=
None
@
staticmethod
def
from_dict
(
obj
,
vocab_size
):
...
...
@@ -251,9 +252,6 @@ class Req:
# The number of tokens that were already cached in the KV cache
self
.
cached_tokens
=
0
# For Qwen2-VL
self
.
mrope_position_delta
=
[]
# use mutable object
# Whether the request has reached a finish condition
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
...
...
@@ -983,8 +981,6 @@ class ScheduleBatch:
global
bid
bid
+=
1
mrope_positions_delta
=
[
req
.
mrope_position_delta
for
req
in
self
.
reqs
]
return
ModelWorkerBatch
(
bid
=
bid
,
forward_mode
=
self
.
forward_mode
,
...
...
@@ -1007,7 +1003,6 @@ class ScheduleBatch:
encoder_out_cache_loc
=
self
.
encoder_out_cache_loc
,
lora_paths
=
[
req
.
lora_path
for
req
in
self
.
reqs
],
sampling_info
=
self
.
sampling_info
,
mrope_positions_delta
=
mrope_positions_delta
,
)
def
copy
(
self
):
...
...
@@ -1074,9 +1069,6 @@ class ModelWorkerBatch:
# Sampling info
sampling_info
:
SamplingBatchInfo
# For Qwen2-VL
mrope_positions_delta
:
List
[
List
[
int
]]
def
copy
(
self
):
return
dataclasses
.
replace
(
self
,
sampling_info
=
self
.
sampling_info
.
copy
())
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
a8aad935
...
...
@@ -136,8 +136,13 @@ class ForwardBatch:
mrope_positions_list
=
[
None
]
*
self
.
seq_lens
.
shape
[
0
]
if
self
.
forward_mode
.
is_decode
():
for
i
,
_
in
enumerate
(
mrope_positions_list
):
mrope_position_delta
=
(
0
if
batch
.
image_inputs
[
i
]
is
None
else
batch
.
image_inputs
[
i
].
mrope_position_delta
)
mrope_positions_list
[
i
]
=
MRotaryEmbedding
.
get_next_input_positions
(
batch
.
mrope_positions_delta
[
i
][
0
]
,
mrope_position_delta
,
int
(
self
.
seq_lens
[
i
])
-
1
,
int
(
self
.
seq_lens
[
i
]),
)
...
...
@@ -159,7 +164,6 @@ class ForwardBatch:
)
]
]
*
3
mrope_position_delta
=
0
else
:
# TODO: Qwen2-VL does not currently support the radix cache because of how mrope positions are calculated
mrope_positions
,
mrope_position_delta
=
(
...
...
@@ -173,8 +177,8 @@ class ForwardBatch:
context_len
=
0
,
)
)
batch
.
image_inputs
[
i
].
mrope_position_delta
=
mrope_position_delta
mrope_positions_list
[
i
]
=
mrope_positions
batch
.
mrope_positions_delta
[
i
].
append
(
mrope_position_delta
)
self
.
mrope_positions
=
torch
.
concat
(
[
torch
.
tensor
(
pos
,
device
=
device
)
for
pos
in
mrope_positions_list
],
...
...
python/sglang/srt/models/qwen2_vl.py
View file @
a8aad935
...
...
@@ -649,8 +649,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
]
image_embeds_offset
+=
num_image_tokens
input_ids
=
None
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment