change / sglang / Commits / e53a0b3d

Commit e53a0b3d (Unverified)
[fix] fix mrope positions not picked up (#5265)
Authored Apr 11, 2025 by Mick; committed by GitHub, Apr 11, 2025
Parent: 038bc5d5

Showing 7 changed files with 69 additions and 69 deletions
python/sglang/srt/layers/attention/vision.py             +1   -1
python/sglang/srt/managers/schedule_batch.py             +6   -7
python/sglang/srt/model_executor/forward_batch_info.py   +44  -10
python/sglang/srt/model_executor/model_runner.py         +2   -1
python/sglang/srt/models/qwen2_5_vl.py                   +11  -48
python/sglang/srt/models/qwen2_vl.py                     +3   -2
python/sglang/srt/openai_api/adapter.py                  +2   -0
python/sglang/srt/layers/attention/vision.py

@@ -94,7 +94,7 @@ class VisionAttention(nn.Module):
             input_size=embed_dim,
             output_size=embed_dim,
             quant_config=quant_config,
-            prefix=add_prefix("out_proj", prefix),
+            prefix=add_prefix("proj", prefix),
         )

     def forward(
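The one-line change above renames the weight-name prefix handed to the attention output projection so that the constructed parameter names match the checkpoint keys at load time. A minimal sketch of how such a prefix helper plausibly behaves; the add_prefix body below is an assumption for illustration, not copied from sglang:

# Hypothetical sketch of a prefix helper like the add_prefix used above.
# Assumed behavior: join a child module name onto a dotted parent prefix.
def add_prefix(name: str, prefix: str) -> str:
    return name if not prefix else f"{prefix}.{name}"

# Weight loading matches parameter names against checkpoint keys, so the
# prefix must produce the name the checkpoint actually uses ("proj").
assert add_prefix("proj", "visual.blocks.0.attn") == "visual.blocks.0.attn.proj"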
python/sglang/srt/managers/schedule_batch.py

@@ -268,6 +268,9 @@ class MultimodalDataItem:
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO

@@ -306,11 +309,7 @@ class MultimodalInputs:
         )
         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
         assert len(ret.mm_items) != 0

@@ -345,8 +344,8 @@ class MultimodalInputs:
         """ """
         return any(item.is_audio() for item in self.mm_items)

-    def collect_image_inputs(self) -> List[torch.Tensor]:
-        return [item.pixel_values for item in self.mm_items if item.is_image()]
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())

     def merge(self, other: MultimodalInputs):
         """
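is_valid() gives MultimodalDataItem a single modality check, so the from_dict filtering and the new contains_mm_input no longer each spell out the is_audio()/is_image()/is_video() triple. A self-contained sketch of the pattern, using simplified stand-in classes rather than the real sglang dataclasses:

# Simplified stand-ins for MultimodalDataItem / MultimodalInputs to show
# the is_valid() filtering pattern introduced above (not the real classes).
from dataclasses import dataclass, field
from typing import List

@dataclass
class Item:
    modality: str  # "IMAGE", "VIDEO", or "AUDIO"
    def is_image(self) -> bool: return self.modality == "IMAGE"
    def is_video(self) -> bool: return self.modality == "VIDEO"
    def is_audio(self) -> bool: return self.modality == "AUDIO"
    def is_valid(self) -> bool:
        # One place to extend when a new modality is added.
        return self.is_image() or self.is_video() or self.is_audio()

@dataclass
class Inputs:
    mm_items: List[Item] = field(default_factory=list)
    def contains_mm_input(self) -> bool:
        return any(item.is_valid() for item in self.mm_items)

items = [Item("IMAGE"), Item("TEXT"), Item("AUDIO")]
kept = [item for item in items if item.is_valid()]
assert len(kept) == 2 and Inputs(kept).contains_mm_input()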
python/sglang/srt/model_executor/forward_batch_info.py

@@ -33,7 +33,6 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

-import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -399,13 +398,13 @@ class ForwardBatch:
             )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
+            for i, mm_input in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if multimodal_inputs is None:
+                if mm_input is None:
                     # text only
                     mrope_positions = [
                         [
@@ -416,23 +415,58 @@ class ForwardBatch:
                         ]
                     ] * 3
                 else:
+                    image_grid_thws_list = [
+                        item.image_grid_thws
+                        for item in mm_input.mm_items
+                        if item.image_grid_thws is not None
+                    ]
+                    image_grid_thw = (
+                        None
+                        if len(image_grid_thws_list) == 0
+                        else torch.cat(image_grid_thws_list, dim=0)
+                    )
+
+                    video_grid_thws_list = [
+                        item.video_grid_thws
+                        for item in mm_input.mm_items
+                        if item.video_grid_thws is not None
+                    ]
+                    video_grid_thw = (
+                        None
+                        if len(video_grid_thws_list) == 0
+                        else torch.cat(video_grid_thws_list, dim=0)
+                    )
+
+                    second_per_grid_ts_list = [
+                        item.second_per_grid_ts
+                        for item in mm_input.mm_items
+                        if item.second_per_grid_ts is not None
+                    ]
+                    second_per_grid_ts = (
+                        None
+                        if len(second_per_grid_ts_list) == 0
+                        else torch.cat(second_per_grid_ts_list, dim=0)
+                    )
+
                     # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=multimodal_inputs.image_grid_thws,
-                            video_grid_thw=multimodal_inputs.video_grid_thws,
-                            image_token_id=multimodal_inputs.im_token_id,
-                            video_token_id=multimodal_inputs.video_token_id,
+                            ].tolist(),
+                            image_grid_thw=image_grid_thw,
+                            video_grid_thw=video_grid_thw,
+                            image_token_id=hf_config.image_token_id,
+                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
                             vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
                             seq_len=len(self.input_ids),
-                            second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
-                            tokens_per_second=hf_config.vision_config.tokens_per_second,
+                            second_per_grid_ts=second_per_grid_ts,
+                            tokens_per_second=getattr(
+                                hf_config.vision_config, "tokens_per_second", None
+                            ),
                         )
                     )
                 batch.multimodal_inputs[i].mrope_position_delta = (
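The substance of the hunk above: grid metadata now lives on the individual mm_items, so the extend path first gathers each item's optional image_grid_thws / video_grid_thws / second_per_grid_ts tensors and concatenates them (or passes None when absent) before calling MRotaryEmbedding.get_input_positions, and token ids now come from hf_config rather than from the inputs object. A runnable sketch of the gather-and-concat step and of the text-only fallback, with assumed (num_images, 3) grid shapes:

# Minimal sketch of the gather-and-concat pattern above, assuming each
# item carries an optional (num_images, 3) tensor of (t, h, w) grids.
from typing import List, Optional
import torch

def concat_optional(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
    present = [t for t in tensors if t is not None]
    # None (not an empty tensor) signals "no media of this kind in the batch".
    return None if len(present) == 0 else torch.cat(present, dim=0)

item_grids = [torch.tensor([[1, 4, 6]]), None, torch.tensor([[1, 8, 8]])]
image_grid_thw = concat_optional(item_grids)
assert image_grid_thw.shape == (2, 3)

# Text-only requests just replicate ordinary positions across the 3 mrope rows.
seq_len, prefix_len = 4, 0
mrope_positions = [list(range(prefix_len, prefix_len + seq_len))] * 3
assert len(mrope_positions) == 3 and mrope_positions[0] == [0, 1, 2, 3]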
python/sglang/srt/model_executor/model_runner.py

@@ -1070,7 +1070,8 @@ class ModelRunner:
         rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
-        return rope_scaling.get("type", None) == "mrope"
+        is_mrope_enabled = "mrope_section" in rope_scaling
+        return is_mrope_enabled

     def save_remote_model(self, url: str):
         from sglang.srt.model_loader.loader import RemoteModelLoader
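This hunk is the heart of the fix named in the commit title. Qwen2-VL checkpoints ship rope_scaling = {"type": "mrope", "mrope_section": [...]}, but transformers normalizes the "mrope" type to "default" when the config is loaded, so rope_scaling.get("type") == "mrope" quietly became False and mrope positions were never picked up. Testing for the mrope_section key survives the normalization. An illustrative check; the config dicts below are examples patterned on Qwen2-VL, not copied from a specific checkpoint:

# Illustrative rope_scaling dicts patterned on Qwen2-VL-style HF configs.
older_style = {"type": "mrope", "mrope_section": [16, 24, 24]}
normalized = {"type": "default", "mrope_section": [16, 24, 24]}

def is_mrope_enabled(rope_scaling) -> bool:
    if rope_scaling is None:
        return False
    # Keying on mrope_section works for both spellings of the config.
    return "mrope_section" in rope_scaling

assert is_mrope_enabled(older_style) and is_mrope_enabled(normalized)
# The old check broke on the normalized form:
assert normalized.get("type", None) != "mrope"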
python/sglang/srt/models/qwen2_5_vl.py

@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x


-class Qwen2_5_VisionPatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
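The deleted class duplicated the upstream implementation now imported from transformers (see the import hunk above). Its shape contract: input is flattened patches of shape (L, C) with C = in_channels * temporal_patch_size * patch_size * patch_size, and a Conv3d whose kernel equals its stride projects each patch to embed_dim. A standalone sketch of that arithmetic, not the transformers class itself:

# Standalone sketch of the patch-embed shape contract of the removed class
# (and of the transformers class now imported in its place).
import torch
import torch.nn as nn

patch_size, temporal_patch_size, in_channels, embed_dim = 14, 2, 3, 1152
kernel = [temporal_patch_size, patch_size, patch_size]
proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=False)

L = 16  # number of flattened patches
x = torch.randn(L, in_channels * temporal_patch_size * patch_size * patch_size)
x = x.view(L, in_channels, temporal_patch_size, patch_size, patch_size)
out = proj(x).view(L, embed_dim)  # one embedding per patch
assert out.shape == (L, embed_dim)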
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out


-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
     def __init__(
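Likewise the rotary table matched the upstream class: inverse frequencies follow the usual 1 / theta^(2i/dim) schedule and the forward pass is an outer product of positions with them. A worked sketch of the same math:

# Sketch of the removed rotary table (same math as the transformers
# Qwen2_5_VisionRotaryEmbedding now imported instead).
import torch

def rotary_freqs(seqlen: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    seq = torch.arange(seqlen, dtype=inv_freq.dtype)
    return torch.outer(seq, inv_freq)  # shape (seqlen, dim // 2)

freqs = rotary_freqs(seqlen=8, dim=32)
assert freqs.shape == (8, 16) and freqs[0].abs().sum() == 0  # position 0 -> zeros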
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads

@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
@@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions

         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
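With the detection fixed, the forward path computes is_mrope_enabled once and then relies on the invariant that mrope positions arrive as a (3, seq_len) tensor, one row each for the temporal, height, and width position channels. A tiny sketch of that invariant:

# Sketch of the (3, seq_len) mrope-positions invariant asserted above.
import torch

seq_len = 5
# For text-only content all three channels degenerate to the same arange.
positions = torch.arange(seq_len).unsqueeze(0).expand(3, -1)
assert positions.ndim == 2 and positions.size(0) == 3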
python/sglang/srt/models/qwen2_vl.py

@@ -521,14 +521,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions

         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
python/sglang/srt/openai_api/adapter.py

@@ -983,6 +983,8 @@ def v1_chat_generate_request(
             ):
                 encoded = encoded[1:]
             prompt_ids += encoded
+        if tokenizer_manager.model_config.is_multimodal:
+            prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
         stop = request.stop
         image_data = None
         audio_data = None
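For multimodal models the adapter now re-decodes the accumulated prompt_ids back into a prompt string, since the multimodal processing path consumes text rather than token ids. A minimal stand-in showing the round-trip; the toy tokenizer below is illustrative, not the sglang tokenizer_manager:

# Minimal stand-in showing why the decode step exists: multimodal
# processors consume the prompt as text, not as token ids.
class ToyTokenizer:  # stand-in for tokenizer_manager.tokenizer
    vocab = {0: "Describe", 1: " this", 2: " image."}
    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)

prompt_ids = [0, 1, 2]
prompt = ToyTokenizer().decode(prompt_ids)  # what the multimodal path receives
assert prompt == "Describe this image."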