change / sglang / Commits / def55bc8

Unverified commit def55bc8, authored Oct 25, 2024 by yizhang2077, committed via GitHub on Oct 25, 2024.
Qwen2vl support cuda graph and disable radix cache (#1780)
Parent: 86a2c473

Showing 5 changed files with 29 additions and 60 deletions (+29 −60):
- README.md (+1 −1)
- python/sglang/srt/layers/rotary_embedding.py (+15 −48)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+5 −0)
- python/sglang/srt/model_executor/forward_batch_info.py (+6 −9)
- python/sglang/srt/model_executor/model_runner.py (+2 −2)
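In short: Qwen2-VL previously ran with CUDA graphs disabled. This commit reworks the M-RoPE position handling (tensor inputs, a static `mrope_positions` buffer in the CUDA graph runner, and an `extend_prefix_len` offset) so that decode can be captured and replayed with CUDA graphs, and disables the radix cache for Qwen2-VL instead, since the M-RoPE position calculation does not yet support it (see the TODO in forward_batch_info.py).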
README.md

```diff
@@ -280,7 +280,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
 - Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral / Mistral NeMo
 - Gemma / Gemma 2
-- Qwen / Qwen 2 / Qwen 2 MoE
+- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
 - DeepSeek / DeepSeek 2
 - OLMoE
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
```
python/sglang/srt/layers/rotary_embedding.py

```diff
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
+        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""

         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()

-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []

         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w

-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )

         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
+        llm_positions += extend_prefix_len

         return llm_positions.tolist(), mrope_position_delta

     @staticmethod
```
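To make the new signature concrete, here is a hypothetical usage sketch. The token ids, the 1×4×4 patch grid, and the prompt layout below are invented for illustration; in real use they come from the Qwen2-VL tokenizer and `hf_config`, and the method is called by `ForwardBatch` as shown in forward_batch_info.py further down. Note the new code locates images purely via `vision_start_token_id`; the image-pad ids themselves are no longer inspected.

```python
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

# Made-up ids for illustration; real values come from the Qwen2-VL config
# (hf_config.vision_start_token_id etc.).
VISION_START, IMAGE_PAD, TXT = 151652, 151655, 100

# Prompt layout: 3 text tokens, <vision_start>, 4 image-pad tokens
# (a 1x4x4 patch grid merged 2x2 -> 1*2*2 = 4 positions), 2 text tokens.
input_tokens = torch.tensor(
    [TXT, TXT, TXT, VISION_START] + [IMAGE_PAD] * 4 + [TXT, TXT]
)

positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens=input_tokens,       # now a torch.Tensor, not List[int]
    image_grid_thw=[[1, 4, 4]],      # (t, h, w) patch grid per image
    vision_start_token_id=VISION_START,
    spatial_merge_size=2,
    context_len=0,
    extend_prefix_len=0,             # new: offset applied for extend prefixes
)
# positions is a 3 x seq_len nested list (temporal/height/width axes);
# delta is used to realign position ids on subsequent decode steps.
```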
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -152,6 +152,7 @@ class CudaGraphRunner:
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -233,6 +234,7 @@ class CudaGraphRunner:
             encoder_lens = None
         seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -259,6 +261,7 @@ class CudaGraphRunner:
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
+            mrope_positions=mrope_positions,
         )
         return forward(input_ids, forward_batch.positions, forward_batch)
@@ -301,6 +304,8 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
```
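These four hunks follow the standard CUDA-graph contract: a replayed graph re-executes fixed memory addresses, so per-batch inputs must live in buffers allocated once before capture and be refreshed in place at replay time. A minimal standalone sketch of that pattern (not sglang code; names are mirrored from the diff for readability):

```python
import torch

max_bs = 8
# Allocated once, before graph capture; the captured graph bakes in
# this tensor's address.
mrope_positions = torch.zeros((3, max_bs), dtype=torch.int32)

def prepare_replay(batch_mrope: torch.Tensor) -> None:
    """Stage a batch's M-RoPE positions without reallocating."""
    bs = batch_mrope.shape[1]
    # In-place copy keeps the pointer the graph was captured with.
    mrope_positions[:, :bs].copy_(batch_mrope)

prepare_replay(torch.arange(6, dtype=torch.int32).reshape(3, 2))
print(mrope_positions[:, :2])
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]], dtype=torch.int32)
```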
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -142,11 +142,12 @@ class ForwardBatch:
                     int(self.seq_lens[i]),
                 )
             )
         elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, image_inputs in enumerate(batch.image_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    self.extend_start_loc[i],
-                    self.extend_seq_lens[i],
-                    self.extend_prefix_lens[i],
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
                 )
                 if image_inputs is None:
                     # text only
@@ -160,20 +161,16 @@ class ForwardBatch:
                     ] * 3
                     mrope_position_delta = 0
                 else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
+                            ],
                             image_grid_thw=image_inputs.image_grid_thws,
-                            video_grid_thw=None,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
+                            extend_prefix_len=extend_prefix_len.item(),
                         )
                     )
                     mrope_positions_list[i] = mrope_positions
```
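Two details here are easy to miss. First, `extend_start_loc` is pulled to the CPU once before the loop, presumably to avoid a device sync on every iteration: indexing a CUDA tensor element by element from Python forces one transfer per access. Second, `input_tokens` is now passed as a tensor slice rather than a Python list, matching the new `get_input_positions` signature. A tiny sketch of the hoisting idea, with made-up values:

```python
import torch

# Made-up offsets; in sglang this tensor lives on the GPU.
extend_start_loc = torch.tensor([0, 5, 9])

# One device-to-host transfer up front ...
extend_start_loc_cpu = extend_start_loc.cpu().numpy()

for i in range(len(extend_start_loc_cpu)):
    # ... then plain numpy indexing inside the loop: no per-item sync.
    start = int(extend_start_loc_cpu[i])
```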
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -125,11 +125,11 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
-        # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
+        # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
         ]:
-            server_args.disable_cuda_graph = True
+            server_args.disable_radix_cache = True

         # Global vars
         if server_args.show_time_cost:
```
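Net effect: for `Qwen2VLForConditionalGeneration`, the server now leaves CUDA graphs enabled and automatically disables the radix cache instead, so users no longer have to trade away CUDA-graph decode performance to serve Qwen2-VL.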