sglang · Commits · d373a48c

Commit d373a48c (unverified), authored Mar 18, 2025 by Mick, committed by GitHub Mar 17, 2025.

fix: second_per_grid_ts should be used to get mrope position (#3682)

Parent: 98be3bd3

Showing 8 changed files with 93 additions and 69 deletions (+93, -69).
- python/sglang/srt/layers/rotary_embedding.py (+20, -6)
- python/sglang/srt/managers/schedule_batch.py (+4, -0)
- python/sglang/srt/model_executor/forward_batch_info.py (+7, -0)
- python/sglang/srt/model_executor/model_runner.py (+4, -2)
- python/sglang/srt/models/qwen2_5_vl.py (+31, -49)
- test/srt/run_suite.py (+1, -1)
- test/srt/test_vision_openai_server.py (+25, -8)
- test/srt/test_vlm_accuracy.py (+1, -3)
python/sglang/srt/layers/rotary_embedding.py

```diff
@@ -880,8 +880,17 @@ class MRotaryEmbedding(RotaryEmbedding):
         spatial_merge_size: int,
         context_len: int = 0,
         seq_len: Optional[int] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        tokens_per_second: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
-        """Get mrope input positions and delta value."""
+        """
+        Get mrope input positions and delta value.
+
+        :arg
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+        """
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
@@ -918,6 +927,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                     )
                     image_index += 1
                     remain_images -= 1
+                    second_per_grid_t = 0
                     ed = ed_image
                 else:
                     t, h, w = (
@@ -925,6 +935,10 @@ class MRotaryEmbedding(RotaryEmbedding):
                         video_grid_thw[video_index][1],
                         video_grid_thw[video_index][2],
                     )
+                    if second_per_grid_ts is not None:
+                        second_per_grid_t = second_per_grid_ts[video_index]
+                    else:
+                        second_per_grid_t = 1.0
                     video_index += 1
                     remain_videos -= 1
                     ed = ed_video
@@ -941,11 +955,11 @@ class MRotaryEmbedding(RotaryEmbedding):
             )
             t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
+                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+                * second_per_grid_t
+                * tokens_per_second
+            ).flatten()
             h_index = (
                 torch.arange(llm_grid_h)
                 .view(1, -1, 1)
```
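To see what the new scaling changes, here is a small worked example (the values are assumed for illustration, not taken from the commit): with `tokens_per_second = 2` and grids that each span 2 seconds, the temporal index now stretches with real elapsed time instead of counting frames.

```python
import torch

# Illustrative values only: 2 temporal grids, 1x2 spatial grid per frame.
llm_grid_t, llm_grid_h, llm_grid_w = 2, 1, 2
second_per_grid_t, tokens_per_second = 2.0, 2

# Old behavior: temporal index ignores the sampling rate -> [0, 0, 1, 1]
t_old = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
)

# New behavior: index scales with seconds per grid -> [0., 0., 4., 4.]
t_new = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
    * second_per_grid_t
    * tokens_per_second
).flatten()

print(t_old.tolist(), t_new.tolist())
```

Videos sampled at different rates therefore map to consistent temporal positions, which is the point of threading `second_per_grid_ts` through the call chain below.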
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -159,6 +159,10 @@ class ImageInputs:
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
+    # Qwen2-VL video related
+    video_token_id: Optional[int] = None
+    video_grid_thws: List[Tuple[int, int, int]] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None

     # deepseek vl2 related
     image_seq_mask: Optional[List[torch.Tensor]] = None
```
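For orientation, a sketch of how the new video fields could be filled from a Hugging Face Qwen2.5-VL processor output; the key names follow transformers' processor, but this helper and its wiring are illustrative only and are not sglang's actual loading path.

```python
from sglang.srt.managers.schedule_batch import ImageInputs

def build_video_image_inputs(processor_out) -> ImageInputs:
    # Hypothetical helper: map HF Qwen2.5-VL processor video outputs onto the
    # fields added in this commit. `processor_out` is assumed to contain
    # `pixel_values_videos`, `video_grid_thw`, and `second_per_grid_ts`.
    inputs = ImageInputs(pixel_values=processor_out["pixel_values_videos"])
    inputs.video_grid_thws = processor_out["video_grid_thw"].tolist()
    inputs.second_per_grid_ts = processor_out["second_per_grid_ts"]
    return inputs
```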
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -402,9 +402,16 @@ class ForwardBatch:
                             extend_start_loc : extend_start_loc + extend_seq_len
                         ],
                         image_grid_thw=image_inputs.image_grid_thws,
+                        video_grid_thw=image_inputs.video_grid_thws,
+                        image_token_id=image_inputs.im_token_id,
+                        video_token_id=image_inputs.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
                         spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                         context_len=0,
                         seq_len=len(self.input_ids),
+                        second_per_grid_ts=image_inputs.second_per_grid_ts,
+                        tokens_per_second=hf_config.vision_config.tokens_per_second,
                     )
                 )
                 batch.image_inputs[i].mrope_position_delta = mrope_position_delta
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -258,10 +258,12 @@ class ModelRunner:
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True
```
python/sglang/srt/models/qwen2_5_vl.py

```diff
@@ -125,12 +125,15 @@ class Qwen2_5_VisionBlock(nn.Module):
         if attn_implementation == "sdpa":
             use_context_forward = False
             softmax_in_single_precision = False
+            flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
             use_context_forward = True
+            flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
             use_context_forward = False
+            flatten_batch = True

         self.attn = VisionAttention(
             embed_dim=dim,
@@ -139,7 +142,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
-            flatten_batch=True,
+            flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -192,9 +195,10 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.proj.weight.dtype
         L, C = x.shape
         x = x.view(
             L, -1, self.temporal_patch_size, self.patch_size, self.patch_size
         )
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
         return x
```
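The new cast matters because the vision tower typically runs in reduced precision while raw pixel patches arrive as float32. A minimal reproduction of the mismatch (channel counts and dtypes assumed, with a small output dim for brevity):

```python
import torch
from torch import nn

# Assumed shapes: 3 input channels, (2, 14, 14) patch kernel. Conv3d with
# bf16 weights rejects float32 input, which is why forward() now casts the
# input to the weight dtype before the projection.
proj = nn.Conv3d(3, 8, kernel_size=(2, 14, 14), stride=(2, 14, 14)).to(torch.bfloat16)
x = torch.randn(1, 3, 2, 14, 14)  # float32 pixel patches

try:
    proj(x)  # dtype mismatch between input and weights
except RuntimeError as err:
    print(f"mismatch: {err}")

out = proj(x.to(proj.weight.dtype))  # succeeds after the cast
```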
```diff
@@ -246,35 +250,15 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        self.dim = dim
-        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs

     def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs


 class Qwen2_5_VisionTransformer(nn.Module):
```
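This drops the grow-by-doubling frequency cache in favor of recomputing the angle table per call. As a sanity sketch (dimensions assumed), the simplified forward just materializes one rotation angle per (position, frequency) pair:

```python
import torch

dim, seqlen, theta = 16, 8, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=inv_freq.dtype), inv_freq)
assert freqs.shape == (seqlen, dim // 2)  # (positions, frequencies)
```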
```diff
@@ -293,7 +277,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_chans
+        in_chans: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
```
```diff
@@ -393,27 +377,24 @@ class Qwen2_5_VisionTransformer(nn.Module):
         pos_ids = []
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = (
-                hpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
-            )
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()

-            wpos_ids = (
-                wpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
-            )
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
```
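What the reshape/permute/flatten sequence does is easier to see on a toy grid (sizes assumed): for h = w = 4 with spatial_merge_size = 2, the row indices come out grouped into 2x2 merged patches.

```python
import torch

h, w, m = 4, 4, 2  # toy grid; m = spatial_merge_size
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
print(hpos.tolist())
# [0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3]
```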
```diff
@@ -437,7 +418,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=x.device,
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            dtype=torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
```
```diff
@@ -610,7 +591,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 start_idx = extend_start_loc_cpu[i]
                 prefix_len = prefix_lens_cpu[i]

-                pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
+                pixel_values = image.pixel_values.to(device="cuda")
                 image_grid_thws = torch.tensor(
                     np.array(image.image_grid_thws), device="cuda"
                 )
```
test/srt/run_suite.py

```diff
@@ -68,7 +68,7 @@ suites = {
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
         TestFile("test_vision_chunked_prefill.py", 223),
-        TestFile("test_vision_llm.py", 18.4),
+        TestFile("test_vlm_accuracy.py", 60),
         TestFile("test_vision_openai_server.py", 344),
         TestFile("test_fim_completion.py", 120),
         TestFile("test_w8a8_quantization.py", 46),
```
test/srt/test_vision_openai_server.py

```diff
@@ -191,7 +191,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
         # from transformers import AutoTokenizer
         from decord import VideoReader, cpu

-        max_frames_num = 12
+        max_frames_num = 20
         vr = VideoReader(video_path, ctx=cpu(0))
         total_frame_num = len(vr)
         uniform_sampled_frames = np.linspace(
@@ -226,6 +226,22 @@ class TestOpenAIVisionServer(unittest.TestCase):
         return messages

+    def prepare_video_messages_video_direct(self, video_path):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"video:{video_path}"},
+                        "modalities": "video",
+                    },
+                    {"type": "text", "text": "Please describe the video in detail."},
+                ],
+            },
+        ]
+        return messages
+
     def test_video_chat_completion(self):
         url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
         cache_dir = os.path.expanduser("~/.cache")
@@ -241,6 +257,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)

+        # messages = self.prepare_video_messages_video_direct(file_path)
         messages = self.prepare_video_messages(file_path)

         video_request = client.chat.completions.create(
@@ -266,6 +283,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
             "man" in video_response
             or "person" in video_response
             or "individual" in video_response
+            or "speaker" in video_response
         ), video_response
         assert (
             "present" in video_response
@@ -368,7 +386,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
             list(executor.map(self.run_decode_with_image, image_ids))


-class TestQWen2VLServer(TestOpenAIVisionServer):
+class TestQwen2VLServer(TestOpenAIVisionServer):
     @classmethod
     def setUpClass(cls):
         cls.model = "Qwen/Qwen2-VL-7B-Instruct"
@@ -382,14 +400,14 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
             other_args=[
                 "--chat-template",
                 "qwen2-vl",
+                "--chunked-prefill-size",
+                "10000",
+                "--mem-fraction-static",
+                "0.4",
             ],
         )
         cls.base_url += "/v1"


-class TestQWen2_5_VLServer(TestOpenAIVisionServer):
+class TestQwen2_5_VLServer(TestOpenAIVisionServer):
     @classmethod
     def setUpClass(cls):
         cls.model = "Qwen/Qwen2.5-VL-7B-Instruct"
@@ -403,9 +421,6 @@ class TestQWen2_5_VLServer(TestOpenAIVisionServer):
             other_args=[
                 "--chat-template",
                 "qwen2-vl",
-                # FIXME: workaround to chunked prefill within image embeds
-                "--chunked-prefill-size",
-                "10000",
                 "--mem-fraction-static",
                 "0.4",
             ],
@@ -508,6 +523,8 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
                 "--trust-remote-code",
                 "--chat-template",
                 "minicpmv",
+                "--mem-fraction-static",
+                "0.4",
             ],
         )
         cls.base_url += "/v1"
```
test/srt/test_vision_llm.py → test/srt/test_vlm_accuracy.py (renamed)

```diff
@@ -17,8 +17,6 @@
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs

-MiniCPMV = "openbmb/MiniCPM-V-2_6"
-

 # Test the logits output between HF and SGLang
 class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
@@ -155,7 +153,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
-        cls.model_path = MiniCPMV
+        cls.model_path = "openbmb/MiniCPM-V-2_6"
         cls.tokenizer = AutoTokenizer.from_pretrained(
             cls.model_path, trust_remote_code=True
         )
```