change / sglang / Commits / 94cde109

Unverified commit 94cde109, authored Oct 21, 2024 by Liangsheng Yin, committed by GitHub on Oct 21, 2024.

    Llama3.2 vision model support (#1551)

Parent: 00611286

Showing 20 changed files with 1536 additions and 118 deletions (+1536, -118).

python/pyproject.toml                                            +28    -9
python/sglang/bench_latency.py                                    +2    -1
python/sglang/lang/chat_template.py                               +1    -0
python/sglang/srt/configs/model_config.py                         +2    -0
python/sglang/srt/conversation.py                                +13    -0
python/sglang/srt/layers/attention/__init__.py                   +11    -4
python/sglang/srt/layers/attention/double_sparsity_backend.py    +17    -5
python/sglang/srt/layers/attention/flashinfer_backend.py        +121   -36
python/sglang/srt/layers/attention/triton_backend.py             +17    -5
python/sglang/srt/managers/image_processor.py                    +72   -16
python/sglang/srt/managers/schedule_batch.py                    +148   -17
python/sglang/srt/managers/scheduler.py                           +2    -1
python/sglang/srt/managers/tokenizer_manager.py                  +10    -4
python/sglang/srt/mem_cache/memory_pool.py                       +15    -6
python/sglang/srt/model_executor/cuda_graph_runner.py            +53    -7
python/sglang/srt/model_executor/forward_batch_info.py           +12    -1
python/sglang/srt/model_executor/model_runner.py                  +2    -5
python/sglang/srt/models/mllama.py                             +1004    -0
python/sglang/srt/models/qwen2_vl.py                              +5    -1
python/sglang/srt/utils.py                                        +1    -0

python/pyproject.toml

@@ -8,16 +8,12 @@ version = "0.3.4"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
-license = { file = "LICENSE" }
+license = { file = "LICENSE" }
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: Apache Software License",
 ]
-dependencies = [
-    "requests",
-    "tqdm",
-    "numpy",
-]
+dependencies = ["requests", "tqdm", "numpy"]

 [project.optional-dependencies]
 runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",

@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
+test = [
+    "jsonlines",
+    "matplotlib",
+    "pandas",
+    "sentence_transformers",
+    "accelerate",
+    "peft",
+]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]

@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"

 [tool.setuptools.packages.find]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]

 [tool.wheel]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]

python/sglang/bench_latency.py

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(model_runner.model_config.vocab_size)
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)

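Together with the matching scheduler change later in this commit, the hunk above moves the model config into the batch itself. The following is a small mock, not sglang code, illustrating why: later stages can read vocab_size and is_encoder_decoder from the batch instead of each call site passing vocab_size in. All names and values below are made up.

    from dataclasses import dataclass

    @dataclass
    class FakeModelConfig:
        vocab_size: int = 128256          # illustrative value only
        is_encoder_decoder: bool = True   # True for the mllama architecture

    @dataclass
    class FakeBatch:
        model_config: FakeModelConfig

        def prepare_for_extend(self):
            # vocab_size no longer needs to be passed in by the caller
            vocab_size = self.model_config.vocab_size
            if self.model_config.is_encoder_decoder:
                pass  # here the real code strips encoder tokens, fills encoder_lens, etc.
            return vocab_size

    batch = FakeBatch(FakeModelConfig())
    print(batch.prepare_for_extend())  # 128256
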
python/sglang/lang/chat_template.py

@@ -229,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

python/sglang/srt/configs/model_config.py

@@ -89,6 +89,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""

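A tiny, self-contained illustration of the new flag (hf_config here is a stand-in object, not a real HuggingFace config): Llama 3.2 vision checkpoints report model_type "mllama", which is what flips the runtime onto the encoder-decoder (cross-attention) path.

    from types import SimpleNamespace

    # Hypothetical stand-in for the HuggingFace config; the real code reads
    # self.hf_config.model_type.
    hf_config = SimpleNamespace(model_type="mllama")
    is_encoder_decoder = hf_config.model_type in ["mllama"]
    print(is_encoder_decoder)  # True -> encoder-decoder path (cross attention, encoder_lens, ...)
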
python/sglang/srt/conversation.py

@@ -509,6 +509,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",

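For orientation, a rough, self-contained sketch of how the fields registered above combine into a prompt for a single user turn with one image; the real rendering is done by the Conversation class with SeparatorStyle.LLAMA3, so the exact output may differ slightly.

    # Illustrative only: approximates the llama_3_vision prompt layout.
    system_template = "<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
    system_message = (
        "You are a helpful language and vision assistant. You are able to understand the "
        "visual content that the user provides, and assist the user with a variety of tasks "
        "using natural language."
    )
    image_token = "<|image|>"
    user_text = "Describe this picture."

    prompt = (
        system_template.format(system_message=system_message)
        + "<|start_header_id|>user<|end_header_id|>\n\n"
        + image_token
        + user_text
        + "<|eot_id|>"
        + "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    print(prompt)
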
python/sglang/srt/layers/attention/__init__.py

 from abc import ABC, abstractmethod
+from typing import Optional

 import torch
 from torch import nn

+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run forward on an attention layer."""

@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for decode."""

@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for extend."""

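An illustrative skeleton, not part of this commit, of what a concrete backend has to accept after this interface change: an optional encoder_lens for cuda-graph metadata and a RadixAttention layer object (rather than a bare nn.Module) in the forward paths. The class and method bodies below are placeholders.

    from typing import Optional
    import torch

    class MyAttnBackend:
        def init_forward_metadata_capture_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            encoder_lens: Optional[torch.Tensor] = None,
        ):
            # Backends that do not support cross attention can require encoder_lens
            # to be None or all zeros, as the triton/double-sparsity backends do.
            ...

        def forward_decode(self, q, k, v, layer, forward_batch):
            # `layer` is the RadixAttention module; backends read attributes such as
            # layer.layer_id, layer.tp_q_head_num, layer.head_dim and
            # layer.is_cross_attention directly from it.
            ...
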
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         (

@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num

python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -11,7 +11,6 @@ from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn
 import triton
 import triton.language as tl

@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

 if is_flashinfer_available():

@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
         assert not (
             model_runner.sliding_window_size is not None
-            and model_runner.has_cross_attention
+            and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
-        elif model_runner.has_cross_attention:
+        elif model_runner.model_config.is_encoder_decoder:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
         else:

@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
+                decode_wrappers=None,
+                encoder_lens=forward_batch.encoder_lens,
             )
             self.forward_metadata = (self.decode_wrappers,)
         else:

@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged,
+                use_ragged=use_ragged,
+                encoder_lens=forward_batch.encoder_lens,
             )

-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-            )
+            self.forward_metadata = (use_ragged, extend_no_prefix)

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(

@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: torch.Tensor = None,
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):

@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum = seq_lens.sum().item()
         self.indices_updater_decode.update(
-            req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            decode_wrappers=decode_wrappers,
+            encoder_lens=encoder_lens,
         )
         self.cuda_graph_metadata[bs] = decode_wrappers
         self.forward_metadata = (decode_wrappers,)

@@ -192,34 +200,42 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: torch.Tensor = None,
     ):
         self.indices_updater_decode.update(
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            self.cuda_graph_metadata[bs],
+            decode_wrappers=self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
         ]

         use_ragged, extend_no_prefix = self.forward_metadata
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, forward_batch.out_cache_loc, k, v
-                )
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=True,
+                causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=layer.logit_cap,

@@ -247,20 +263,23 @@
                 o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def _get_wrapper_idx(self, layer: nn.Module):
+    def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0

@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -305,20 +326,27 @@
         self.decode_wrappers = attn_backend.decode_wrappers

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(
+        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+    ):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(

@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -363,8 +392,35 @@
                 kv_start_idx_tmp,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self,
+        req_pool_indices,
+        seq_lens,
+        seq_lens_sum,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # Cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+                seq_lens_sum = encoder_lens.sum().item()
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens_sum,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )

     def call_begin_forward(
         self,

@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
         self.wrappers_paged = attn_backend.prefill_wrappers_paged

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
         )

     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -487,8 +549,31 @@
                 use_ragged,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )

     def call_begin_forward(
         self,

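The core of the new cross-attention dispatch is the per-wrapper choice of KV lengths and start offsets. The following is a self-contained restatement of that loop with made-up lengths: wrapper 0 (normal attention) starts its KV range after the encoder tokens, while wrapper 1 (cross attention) attends over exactly the encoder tokens and starts at zero.

    import torch

    seq_lens = torch.tensor([7, 12])        # decoder-side lengths per request (illustrative)
    encoder_lens = torch.tensor([100, 64])  # number of vision/encoder tokens per request

    for wrapper_id in range(2):
        if wrapper_id == 0:
            # Normal attention: skip past the encoder tokens
            paged_kernel_lens = seq_lens
            kv_start_idx = encoder_lens
        else:
            # Cross attention: attend over exactly the encoder tokens
            paged_kernel_lens = encoder_lens
            kv_start_idx = torch.zeros_like(encoder_lens)
        print(wrapper_id, paged_kernel_lens.tolist(), kv_start_idx.tolist())
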
python/sglang/srt/layers/attention/triton_backend.py

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
             o = torch.empty_like(q)

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(

python/sglang/srt/managers/image_processor.py

@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs):

 class BaseImageProcessor(ABC):
+    def __init__(self, hf_config, server_args, _processor):
+        self.hf_config = hf_config
+        self._processor = _processor
+
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+        )
+
     @abstractmethod
-    async def process_images_async(self, image_data, **kwargs):
+    async def process_images_async(self, image_data, input_text, **kwargs):
         pass


 class DummyImageProcessor(BaseImageProcessor):
+    def __init__(self):
+        pass
+
     async def process_images_async(self, *args, **kwargs):
         return None


 class LlavaImageProcessor(BaseImageProcessor):
-    def __init__(self, hf_config, server_args, _image_processor):
-        self.hf_config = hf_config
-        self._image_processor = _image_processor
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(server_args,),
-            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
-        )
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)

     @staticmethod
     def _process_single_image_task(

@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         )

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -177,6 +183,54 @@
         }


+class MllamaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(images, input_text):
+        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
+        return global_processor(images, input_text, return_tensors="pt")
+
+    async def _process_single_image(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MllamaImageProcessor._process_single_image_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(images, input_text, return_tensors="pt")
+
+        return image_inputs
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if len(image_data) > 0:
+            images = [load_image(image)[0] for image in image_data]
+        else:
+            images = load_image(image_data[0])[0]
+
+        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs["image_hashes"] = [hash(str(image_data))]
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        return image_inputs
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config

@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         return self._process_single_image_task(image_data)

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -292,12 +346,14 @@

 def get_image_processor(
-    hf_config, server_args: ServerArgs, _image_processor
+    hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
-    if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
+    if "MllamaForConditionalGeneration" in hf_config.architectures:
+        return MllamaImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
     else:
-        return LlavaImageProcessor(hf_config, server_args, _image_processor)
+        return LlavaImageProcessor(hf_config, server_args, processor.image_processor)


 def get_dummy_image_processor():

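For context, a rough usage sketch of the Hugging Face processor call that MllamaImageProcessor delegates to. The model id and image URL are placeholders, and running this requires a transformers version with Llama 3.2 vision support plus access to the gated checkpoint.

    from io import BytesIO

    import requests
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    image = Image.open(BytesIO(requests.get("https://example.com/cat.png").content))
    text = "<|image|>Describe this picture."

    # Returns input_ids, attention_mask, pixel_values, aspect_ratio_ids,
    # aspect_ratio_mask and cross_attention_mask, as noted in the comment above.
    inputs = processor(images=image, text=text, return_tensors="pt")
    print(sorted(inputs.keys()))
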
python/sglang/srt/managers/schedule_batch.py

@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache

@@ -121,11 +122,12 @@ class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-    image_hash: int
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None

@@ -138,19 +140,27 @@
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hash=hash(tuple(obj["image_hashes"])),
-            image_grid_thws=obj.get("image_grid_thws"),
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.image_hash
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-        ret.image_sizes = obj["image_sizes"]
-        # Only when pixel values is not None we have modalities
-        ret.modalities = obj["modalities"] or ["image"]
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret

@@ -416,6 +426,10 @@ class ScheduleBatch:
     req_to_token_pool: ReqToTokenPool = None
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -440,6 +454,12 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

@@ -450,12 +470,20 @@ class ScheduleBatch:
     device: str = "cuda"

     @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
+            model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_regex=any(req.regex_fsm for req in reqs),

@@ -493,7 +521,78 @@ class ScheduleBatch:
         return out_cache_loc

-    def prepare_for_extend(self, vocab_size: int):
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)

@@ -561,8 +660,13 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self, vocab_size, global_server_args_dict["disable_penalizer"]
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):

@@ -752,6 +856,10 @@ class ScheduleBatch:
         return jump_forward_reqs

+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

@@ -766,16 +874,22 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
         if enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs

@@ -802,6 +916,10 @@ class ScheduleBatch:
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True

@@ -828,6 +946,11 @@ class ScheduleBatch:
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )

@@ -850,14 +973,11 @@ class ScheduleBatch:
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]

@@ -887,7 +1007,11 @@ class ScheduleBatch:
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,

@@ -897,6 +1021,7 @@ class ScheduleBatch:
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,

@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]

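A small, self-contained illustration of what prepare_encoder_info_extend does to one request: the first num_image_tokens positions are treated as encoder tokens, stripped from the decoder-side input_ids and sequence length, and their KV slots are routed to encoder_out_cache_loc instead of the decoder cache. All numbers below are made up.

    num_image_tokens = 4                               # im.num_image_tokens for this request
    input_ids = [9001, 9001, 9001, 9001, 15, 16, 17]   # encoder placeholders + text tokens
    seq_len = len(input_ids)
    out_cache_loc = list(range(100, 100 + seq_len))    # one KV slot per extended token

    encoder_len = num_image_tokens
    seq_len -= encoder_len
    decoder_input_ids = input_ids[encoder_len:]
    encoder_out_cache_loc = out_cache_loc[:encoder_len]
    decoder_out_cache_loc = out_cache_loc[encoder_len:]

    print(decoder_input_ids)      # [15, 16, 17]
    print(seq_len)                # 3
    print(encoder_out_cache_loc)  # [100, 101, 102, 103]
    print(decoder_out_cache_loc)  # [104, 105, 106]
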
python/sglang/srt/managers/scheduler.py

@@ -662,8 +662,9 @@ class Scheduler:
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
+            self.model_config,
         )
-        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:

python/sglang/srt/managers/tokenizer_manager.py

@@ -122,7 +122,7 @@ class TokenizerManager:
             # We want to parallelize the image pre-processing so we create an executor for it
             self.image_processor = get_image_processor(
-                self.hf_config, server_args, self.processor.image_processor
+                self.hf_config, server_args, self.processor
             )
         else:
             self.tokenizer = get_tokenizer(

@@ -191,8 +191,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params)
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data, obj
+                    obj.image_data, input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob
                 logprob_start_len = obj.logprob_start_len
                 top_logprobs_num = obj.top_logprobs_num

@@ -217,8 +219,10 @@ class TokenizerManager:
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 if self.is_generation:
                     image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], obj
+                        obj.image_data[index], input_text or input_ids, obj
                     )
+                    if image_inputs and "input_ids" in image_inputs:
+                        input_ids = image_inputs["input_ids"]
                     return_logprob = obj.return_logprob[index]
                     logprob_start_len = obj.logprob_start_len[index]
                     top_logprobs_num = obj.top_logprobs_num[index]

@@ -263,8 +267,10 @@ class TokenizerManager:
                 sampling_params = SamplingParams(**obj.sampling_params[0])
                 sampling_params.max_new_tokens = 0
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[0], obj
+                    obj.image_data[0], input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob[0]
                 logprob_start_len = obj.logprob_start_len[0]
                 top_logprobs_num = obj.top_logprobs_num[0]

python/sglang/srt/mem_cache/memory_pool.py

@@ -26,6 +26,8 @@ from typing import List, Tuple, Union

 import torch

+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)

@@ -41,13 +43,17 @@ class ReqToTokenPool:
         )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records

-        if use_records:
+        # records all write operations
+        if self.use_records:
             self.write = self.write_with_records
         else:
             self.write = self.write_without_records

+    def write(self, indices, values):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def available_size(self):
         return len(self.free_slots)

@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,

@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if cache_v.dtype != self.dtype:

@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:

@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         cache_label: torch.Tensor,
     ):
         # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label

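A minimal mock, not sglang code, of the signature change above: callers now hand over the RadixAttention layer object and the pool reads layer.layer_id itself, which keeps per-layer bookkeeping in one place. The class and shapes below are invented for illustration.

    import torch

    class TinyKVPool:
        def __init__(self, num_layers, num_slots, head_num, head_dim):
            self.k_buffer = [torch.zeros(num_slots, head_num, head_dim) for _ in range(num_layers)]
            self.v_buffer = [torch.zeros(num_slots, head_num, head_dim) for _ in range(num_layers)]

        def set_kv_buffer(self, layer, loc, cache_k, cache_v):
            layer_id = layer.layer_id  # previously the caller passed layer_id directly
            self.k_buffer[layer_id][loc] = cache_k
            self.v_buffer[layer_id][loc] = cache_v

    class FakeLayer:
        layer_id = 0

    pool = TinyKVPool(num_layers=2, num_slots=16, head_num=4, head_dim=8)
    loc = torch.tensor([3, 4])
    pool.set_kv_buffer(FakeLayer(), loc, torch.randn(2, 4, 8), torch.randn(2, 4, 8))
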
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -105,6 +105,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = None
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder

         # Batch sizes to capture
         if self.model_runner.server_args.disable_cuda_graph_padding:

@@ -132,6 +133,9 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

+        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.encoder_len_fill_value = 0
+
         if self.use_torch_compile:
             set_torch_compile_config()

@@ -144,9 +148,18 @@ class CudaGraphRunner:
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

+        if self.is_encoder_decoder:
+            # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
+            self.encoder_lens = torch.full(
+                (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
+            )
+        else:
+            self.encoder_lens = None
+
         # Capture
         try:
-            self.capture()
+            with self.model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"

@@ -157,11 +170,32 @@ class CudaGraphRunner:
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
             )

-    def can_run(self, batch_size: int):
-        if self.disable_padding:
-            return batch_size in self.graphs
-        else:
-            return batch_size <= self.max_bs
+    @contextmanager
+    def model_capture_mode(self):
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = True
+
+        yield
+
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = False
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+
+        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+        # because the full_text_row_masked_out_mask tensor will always be ones
+        is_encoder_lens_supported = (
+            torch.all(forward_batch.encoder_lens > 0)
+            if self.is_encoder_decoder
+            else True
+        )
+
+        return is_bs_supported and is_encoder_lens_supported

     def capture(self):
         with graph_capture() as graph_capture_context:

@@ -188,11 +222,19 @@ class CudaGraphRunner:
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+
         seq_lens_sum = seq_lens.sum().item()

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-            bs, req_pool_indices, seq_lens
+            bs,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
         )

         # Run and capture

@@ -208,6 +250,7 @@ class CudaGraphRunner:
                 attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 seq_lens_sum=seq_lens_sum,
+                encoder_lens=encoder_lens,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
                 positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),

@@ -251,6 +294,8 @@ class CudaGraphRunner:
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        if self.is_encoder_decoder:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(

@@ -258,6 +303,7 @@ class CudaGraphRunner:
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum,
+            self.encoder_lens,
         )

         # Replay

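A standalone sketch of the capture-mode pattern introduced above: a model that exposes a capture_mode attribute has it switched on only while the cuda graph is being captured. The toy classes below are illustrative only, not sglang code.

    from contextlib import contextmanager

    class ToyModel:
        capture_mode = False

    class ToyRunner:
        def __init__(self, model):
            self.model = model

        @contextmanager
        def model_capture_mode(self):
            if hasattr(self.model, "capture_mode"):
                self.model.capture_mode = True
            yield
            if hasattr(self.model, "capture_mode"):
                self.model.capture_mode = False

    runner = ToyRunner(ToyModel())
    with runner.model_capture_mode():
        print(runner.model.capture_mode)  # True while capturing
    print(runner.model.capture_mode)      # False afterwards
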
python/sglang/srt/model_executor/forward_batch_info.py

@@ -108,6 +108,12 @@ class ForwardBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None

+    # Encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # For LoRA
     lora_paths: Optional[List[str]] = None

@@ -194,6 +200,11 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
+            image_inputs=batch.image_inputs,
+            encoder_cached=batch.encoder_cached,
+            encoder_lens=batch.encoder_lens,
+            encoder_lens_cpu=batch.encoder_lens_cpu,
+            encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,

@@ -212,11 +223,11 @@ class ForwardBatch:
                 ],
                 axis=0,
             )

-        ret.image_inputs = batch.image_inputs
         ret.extend_num_tokens = batch.extend_num_tokens
         ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, dtype=torch.int32).to(device, non_blocking=True)
         ret.extend_prefix_lens = torch.tensor(batch.extend_prefix_lens, dtype=torch.int32).to(device, non_blocking=True)

python/sglang/srt/model_executor/model_runner.py

@@ -270,7 +270,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

@@ -510,7 +509,7 @@ class ModelRunner:
                    "Window attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )
-            assert not self.has_cross_attention, (
+            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )

@@ -558,9 +557,7 @@ class ModelRunner:
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-            forward_batch.batch_size
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)

python/sglang/srt/models/mllama.py (new file, mode 100644, +1004 lines)

This diff is collapsed in this view.

python/sglang/srt/models/qwen2_vl.py

@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
             ]
             positions = forward_batch.mrope_positions

-        if image_inputs is None or len(image_inputs) == 0:
+        if (
+            forward_batch.forward_mode.is_decode()
+            or image_inputs is None
+            or len(image_inputs) == 0
+        ):
             inputs_embeds = self.model.embed_tokens(input_ids)
         else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":

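An illustrative restatement of the new condition: during decode steps no new image tokens are being extended, so plain token embeddings suffice even if the batch still carries image_inputs from prefill. The helper below is not sglang code.

    def use_text_embeddings_only(forward_mode_is_decode, image_inputs):
        return forward_mode_is_decode or image_inputs is None or len(image_inputs) == 0

    print(use_text_embeddings_only(True, [object()]))   # True: decode never needs vision embeddings
    print(use_text_embeddings_only(False, []))          # True: no images in the batch
    print(use_text_embeddings_only(False, [object()]))  # False: prefill with images
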
python/sglang/srt/utils.py

@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
         or "LlavaQwenForCausalLM" in model_architectures
         or "LlavaMistralForCausalLM" in model_architectures
         or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
     ):
         return True
