zhaoyu6 / sglang · Commits

Commit def55bc8 (Unverified)
Qwen2vl support cuda graph and disable radix cache (#1780)
Authored Oct 25, 2024 by yizhang2077 · Committed by GitHub, Oct 25, 2024
Parent: 86a2c473
Showing 5 changed files with 29 additions and 60 deletions
README.md                                                +1   -1
python/sglang/srt/layers/rotary_embedding.py             +15  -48
python/sglang/srt/model_executor/cuda_graph_runner.py    +5   -0
python/sglang/srt/model_executor/forward_batch_info.py   +6   -9
python/sglang/srt/model_executor/model_runner.py         +2   -2
README.md  View file @ def55bc8

@@ -280,7 +280,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
 - Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral / Mistral NeMo
 - Gemma / Gemma 2
-- Qwen / Qwen 2 / Qwen 2 MoE
+- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
 - DeepSeek / DeepSeek 2
 - OLMoE
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
...
python/sglang/srt/layers/rotary_embedding.py  View file @ def55bc8

@@ -22,64 +22,33 @@ class MRotaryEmbedding:
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
         video_grid_thw: Union[List[List[int]], torch.Tensor],
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
         vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
+        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []

         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
...
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w

-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )

         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
+        llm_positions += extend_prefix_len

         return llm_positions.tolist(), mrope_position_delta

     @staticmethod
...
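For readers following this file's change: the rewrite drops the old image/video interleaving loop (only image grids flow through this path now), takes `input_tokens` as a tensor instead of a list, and indexes image blocks directly from the `vision_start` markers. The sketch below reproduces the M-RoPE position scheme such a loop builds for a single image; the token ids and grid sizes are made up for illustration, and this is a standalone toy, not SGLang's actual function.

```python
import torch

# Hypothetical token ids and grid sizes, chosen only for illustration.
VISION_START, IMAGE = 151652, 151655
spatial_merge_size = 2
# Prompt: 2 text tokens, a vision-start marker, a 1x4x4 image merged 2x2
# into a 1x2x2 llm grid (4 image tokens), then 2 trailing text tokens.
input_tokens = torch.tensor([1, 2, VISION_START, IMAGE, IMAGE, IMAGE, IMAGE, 3, 4])
t, h, w = 1, 4, 4  # one image_grid_thw entry

llm_t, llm_h, llm_w = t, h // spatial_merge_size, w // spatial_merge_size
ed = int(torch.argwhere(input_tokens == VISION_START).item()) + 1  # image block start

pos_list = []
# Text before the image: the three rotary axes (t, h, w) advance together.
pos_list.append(torch.arange(ed).view(1, -1).expand(3, -1))
# Image block: each axis indexes its own grid dimension, offset by ed.
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
pos_list.append(torch.stack([t_idx, h_idx, w_idx]) + ed)
# Trailing text resumes from (max position used so far) + 1.
st = ed + llm_t * llm_h * llm_w
st_idx = pos_list[-1].max() + 1
pos_list.append(
    torch.arange(input_tokens.shape[0] - st).view(1, -1).expand(3, -1) + st_idx
)

llm_positions = torch.cat(pos_list, dim=1)
print(llm_positions.shape)  # torch.Size([3, 9]): one (t, h, w) triple per token
```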
python/sglang/srt/model_executor/cuda_graph_runner.py  View file @ def55bc8

@@ -152,6 +152,7 @@ class CudaGraphRunner:
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
...
@@ -233,6 +234,7 @@ class CudaGraphRunner:
         encoder_lens = None
         seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
...
@@ -259,6 +261,7 @@ class CudaGraphRunner:
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
+            mrope_positions=mrope_positions,
         )

         return forward(input_ids, forward_batch.positions, forward_batch)
...
@@ -301,6 +304,8 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...
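The `CudaGraphRunner` changes follow the usual CUDA-graph buffer discipline: any tensor a captured graph reads must keep a fixed address, so `mrope_positions` is preallocated at `max_bs`, sliced to the live batch size during capture, and refreshed in place with `copy_` before each replay. Below is a minimal sketch of that pattern with a toy function standing in for the model forward; the names are illustrative, not SGLang's API, and a CUDA device is assumed.

```python
import torch

max_bs = 8
# Preallocated at max batch size; the captured graph will hold this exact storage.
mrope_positions = torch.zeros((3, max_bs), dtype=torch.int32, device="cuda")
out = torch.zeros((3, max_bs), dtype=torch.int32, device="cuda")

def forward(pos):
    # Toy stand-in for a model forward that consumes mrope positions.
    return pos * 2

# Warm up on a side stream, then capture the graph once against the fixed buffer.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out.copy_(forward(mrope_positions))
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out.copy_(forward(mrope_positions))

# Replay: copy fresh per-batch values into the captured buffer, then launch.
new_positions = torch.arange(3 * max_bs, dtype=torch.int32, device="cuda").view(3, max_bs)
mrope_positions.copy_(new_positions)
g.replay()
print(out)  # reflects new_positions * 2 without re-launching eager kernels
```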
python/sglang/srt/model_executor/forward_batch_info.py  View file @ def55bc8

@@ -142,11 +142,12 @@ class ForwardBatch:
                         int(self.seq_lens[i]),
                     )
                 )
         elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, image_inputs in enumerate(batch.image_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    self.extend_start_loc[i],
-                    self.extend_seq_lens[i],
-                    self.extend_prefix_lens[i],
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
                 )
                 if image_inputs is None:
                     # text only
...
@@ -160,20 +161,16 @@ class ForwardBatch:
                     ] * 3
                     mrope_position_delta = 0
                 else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
+                            ],
                             image_grid_thw=image_inputs.image_grid_thws,
                             video_grid_thw=None,
                             image_token_id=hf_config.image_token_id,
                             video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
                             vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
+                            extend_prefix_len=extend_prefix_len.item(),
                         )
                     )
                     mrope_positions_list[i] = mrope_positions
...
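The new `extend_prefix_len` argument matters because in extend (incremental prefill) mode `get_input_positions` only sees the newly scheduled tokens, so the positions it computes start at zero and must be shifted past the already-cached prefix. A hypothetical numeric illustration of that shift:

```python
import torch

# Suppose 5 prompt tokens are already cached and we extend with 4 new ones.
extend_prefix_len = 5
extend_seq_len = 4

# Positions computed over just the new chunk start at 0 on all 3 mrope axes...
llm_positions = torch.arange(extend_seq_len).view(1, -1).expand(3, -1).clone()
# ...so they are shifted to follow the cached prefix, as the diff does.
llm_positions += extend_prefix_len
print(llm_positions[0].tolist())  # [5, 6, 7, 8]
```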
python/sglang/srt/model_executor/model_runner.py  View file @ def55bc8

@@ -125,11 +125,11 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
-            # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
-                server_args.disable_cuda_graph = True
+                server_args.disable_radix_cache = True

         # Global vars
         if server_args.show_time_cost:
...
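Net effect of this last file: the automatic override for Qwen2-VL now targets the correct flag, leaving CUDA graphs enabled and disabling the radix cache instead (prefix reuse is incompatible with the mrope position calculation, per the TODO above). A simplified sketch of the resulting behavior; the dataclass below is a stand-in for illustration, not the real `ServerArgs`:

```python
from dataclasses import dataclass

@dataclass
class ServerArgs:  # simplified stand-in for sglang.srt.server_args.ServerArgs
    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False

def apply_model_overrides(architectures: list, args: ServerArgs) -> ServerArgs:
    # After this commit: Qwen2-VL keeps CUDA graphs but turns off the radix
    # cache automatically.
    if architectures == ["Qwen2VLForConditionalGeneration"]:
        args.disable_radix_cache = True
    return args

args = apply_model_overrides(["Qwen2VLForConditionalGeneration"], ServerArgs())
print(args)  # ServerArgs(disable_cuda_graph=False, disable_radix_cache=True)
```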