sglang · commit 94cde109 (unverified)
Authored Oct 21, 2024 by Liangsheng Yin; committed via GitHub on Oct 21, 2024
Parent commit: 00611286

Llama3.2 vision model support (#1551)

Showing 20 changed files with 1536 additions and 118 deletions (+1536 / -118).
Changed files:

  python/pyproject.toml                                             +28    -9
  python/sglang/bench_latency.py                                     +2    -1
  python/sglang/lang/chat_template.py                                +1    -0
  python/sglang/srt/configs/model_config.py                          +2    -0
  python/sglang/srt/conversation.py                                 +13    -0
  python/sglang/srt/layers/attention/__init__.py                    +11    -4
  python/sglang/srt/layers/attention/double_sparsity_backend.py     +17    -5
  python/sglang/srt/layers/attention/flashinfer_backend.py         +121   -36
  python/sglang/srt/layers/attention/triton_backend.py              +17    -5
  python/sglang/srt/managers/image_processor.py                     +72   -16
  python/sglang/srt/managers/schedule_batch.py                     +148   -17
  python/sglang/srt/managers/scheduler.py                            +2    -1
  python/sglang/srt/managers/tokenizer_manager.py                   +10    -4
  python/sglang/srt/mem_cache/memory_pool.py                        +15    -6
  python/sglang/srt/model_executor/cuda_graph_runner.py             +53    -7
  python/sglang/srt/model_executor/forward_batch_info.py            +12    -1
  python/sglang/srt/model_executor/model_runner.py                   +2    -5
  python/sglang/srt/models/mllama.py                              +1004    -0
  python/sglang/srt/models/qwen2_vl.py                               +5    -1
  python/sglang/srt/utils.py                                         +1    -0
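For context (not part of this diff): once a server is running with one of the newly supported vision checkpoints, requests can go through SGLang's OpenAI-compatible endpoint. A minimal usage sketch follows; the launch flags, port, checkpoint path, and the image_url request format are assumptions about the surrounding tooling, not taken from this commit. The chat template name llama_3_vision is the one registered in python/sglang/srt/conversation.py below.

# Minimal usage sketch (assumed launch flags and port; not part of the commit).
# Launch the server separately, for example:
#   python -m sglang.launch_server \
#       --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \
#       --chat-template llama_3_vision --port 30000
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)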
python/pyproject.toml  (view file @ 94cde109)

@@ -8,16 +8,12 @@ version = "0.3.4"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { file = "LICENSE" }
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: Apache Software License",
 ]
-dependencies = [
-    "requests",
-    "tqdm",
-    "numpy",
-]
+dependencies = ["requests", "tqdm", "numpy"]

 [project.optional-dependencies]
 runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", ...

@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
+test = [
+    "jsonlines",
+    "matplotlib",
+    "pandas",
+    "sentence_transformers",
+    "accelerate",
+    "peft",
+]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]

@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"

 [tool.setuptools.packages.find]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]

 [tool.wheel]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]
python/sglang/bench_latency.py  (view file @ 94cde109)

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(model_runner.model_config.vocab_size)
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
python/sglang/lang/chat_template.py  (view file @ 94cde109)

@@ -229,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )
python/sglang/srt/configs/model_config.py  (view file @ 94cde109)

@@ -89,6 +89,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]

     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
python/sglang/srt/conversation.py  (view file @ 94cde109)

@@ -509,6 +509,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",
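For orientation, the template registered above composes prompts roughly as sketched below. This is illustrative only: the exact rendering of user and assistant turns comes from SeparatorStyle.LLAMA3, which is defined elsewhere and not shown in this diff; the header, image, and end-of-turn tokens are taken from the template fields above.

# Illustrative sketch of a prompt in the "llama_3_vision" format registered above.
# Turn layout is an assumption; the special tokens come from the template itself.
system_message = (
    "You are a helpful language and vision assistant. You are able to understand the "
    "visual content that the user provides, and assist the user with a variety of "
    "tasks using natural language."
)
prompt = (
    f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n<|image|>What is in this picture?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)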
python/sglang/srt/layers/attention/__init__.py  (view file @ 94cde109)

 from abc import ABC, abstractmethod
+from typing import Optional

 import torch
 from torch import nn

+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run forward on an attention layer."""

@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for decode."""

@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for extend."""
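The interface change above is what lets a backend distinguish regular decoder self-attention from cross-attention over encoder (image) tokens: CUDA-graph metadata hooks now receive an optional encoder_lens tensor, and attention layers are typed as RadixAttention. A minimal sketch of a conforming subclass, assuming only the abstract interface shown above (the concrete backends in the following files are the real implementations):

# Sketch only: a backend stub honoring the new signature from the abstract interface above.
from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend


class NoopBackend(AttentionBackend):
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        # encoder_lens holds per-request encoder (image) token counts for
        # encoder-decoder models such as mllama; decoder-only models pass None.
        self.forward_metadata = (req_pool_indices, seq_lens, encoder_lens)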
python/sglang/srt/layers/attention/double_sparsity_backend.py  (view file @ 94cde109)

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         (

@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
python/sglang/srt/layers/attention/flashinfer_backend.py  (view file @ 94cde109)

@@ -11,7 +11,6 @@ from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn
 import triton
 import triton.language as tl

@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

 if is_flashinfer_available():

@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
         assert not (
             model_runner.sliding_window_size is not None
-            and model_runner.has_cross_attention
+            and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
-        elif model_runner.has_cross_attention:
+        elif model_runner.model_config.is_encoder_decoder:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
         else:

@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
+                decode_wrappers=None,
+                encoder_lens=forward_batch.encoder_lens,
             )
             self.forward_metadata = (self.decode_wrappers,)
         else:

@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged,
+                use_ragged=use_ragged,
+                encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-            )
+            self.forward_metadata = (use_ragged, extend_no_prefix)

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(

@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: torch.Tensor = None,
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):

@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum = seq_lens.sum().item()
         self.indices_updater_decode.update(
-            req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            decode_wrappers=decode_wrappers,
+            encoder_lens=encoder_lens,
         )
         self.cuda_graph_metadata[bs] = decode_wrappers
         self.forward_metadata = (decode_wrappers,)

@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: torch.Tensor = None,
     ):
         self.indices_updater_decode.update(
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            self.cuda_graph_metadata[bs],
+            decode_wrappers=self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
         ]

         use_ragged, extend_no_prefix = self.forward_metadata
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, forward_batch.out_cache_loc, k, v
-                )
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=True,
+                causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=layer.logit_cap,

@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def _get_wrapper_idx(self, layer: nn.Module):
+    def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0
@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
+        self.attn_backend = attn_backend

         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
         self.decode_wrappers = attn_backend.decode_wrappers

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(
+        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+    ):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(

@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
                 kv_start_idx_tmp,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self,
+        req_pool_indices,
+        seq_lens,
+        seq_lens_sum,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # Cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+                seq_lens_sum = encoder_lens.sum().item()
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens_sum,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )

     def call_begin_forward(
         self,

@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
+        self.attn_backend = attn_backend

         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
         self.wrappers_paged = attn_backend.prefill_wrappers_paged

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
         )

     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )

     def call_begin_forward(
         self,
python/sglang/srt/layers/attention/triton_backend.py  (view file @ 94cde109)

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
             o = torch.empty_like(q)

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(
python/sglang/srt/managers/image_processor.py  (view file @ 94cde109)

@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs):
 class BaseImageProcessor(ABC):
+    def __init__(self, hf_config, server_args, _processor):
+        self.hf_config = hf_config
+        self._processor = _processor
+
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+        )
+
     @abstractmethod
-    async def process_images_async(self, image_data, **kwargs):
+    async def process_images_async(self, image_data, input_text, **kwargs):
         pass


 class DummyImageProcessor(BaseImageProcessor):
+    def __init__(self):
+        pass
+
     async def process_images_async(self, *args, **kwargs):
         return None


 class LlavaImageProcessor(BaseImageProcessor):
-    def __init__(self, hf_config, server_args, _image_processor):
-        self.hf_config = hf_config
-        self._image_processor = _image_processor
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(server_args,),
-            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
-        )
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)

     @staticmethod
     def _process_single_image_task(

@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         )

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
         }


+class MllamaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(images, input_text):
+        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
+        return global_processor(images, input_text, return_tensors="pt")
+
+    async def _process_single_image(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MllamaImageProcessor._process_single_image_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(images, input_text, return_tensors="pt")
+
+        return image_inputs
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if len(image_data) > 0:
+            images = [load_image(image)[0] for image in image_data]
+        else:
+            images = load_image(image_data[0])[0]
+
+        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs["image_hashes"] = [hash(str(image_data))]
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        return image_inputs
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config

@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             return self._process_single_image_task(image_data)

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
 def get_image_processor(
-    hf_config, server_args: ServerArgs, _image_processor
+    hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
-    if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
+    if "MllamaForConditionalGeneration" in hf_config.architectures:
+        return MllamaImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
     else:
-        return LlavaImageProcessor(hf_config, server_args, _image_processor)
+        return LlavaImageProcessor(hf_config, server_args, processor.image_processor)


 def get_dummy_image_processor():
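The new MllamaImageProcessor simply forwards the images plus the already-templated text to the underlying Hugging Face processor in a worker process. A hedged sketch of what that underlying call returns; it assumes a transformers release with Mllama support and access to the gated meta-llama checkpoint (neither is part of this diff), and it mirrors the positional (images, text) call used in _process_single_image_task above.

# Hedged sketch, not part of the commit: exercising the wrapped processor directly.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
inputs = processor(
    Image.new("RGB", (560, 560)),        # a blank stand-in image
    "<|image|>Describe this image.",     # text that already contains the image token
    return_tensors="pt",
)
print(sorted(inputs.keys()))
# Per the comment in _process_single_image_task above, the expected keys are:
# input_ids, attention_mask, pixel_values, aspect_ratio_ids, aspect_ratio_mask,
# cross_attention_mask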
python/sglang/srt/managers/schedule_batch.py  (view file @ 94cde109)

@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache

@@ -121,11 +122,12 @@ class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-    image_hash: int
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None

@@ -138,19 +140,27 @@ class ImageInputs:
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hash=hash(tuple(obj["image_hashes"])),
-            image_grid_thws=obj.get("image_grid_thws"),
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.image_hash
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-        ret.image_sizes = obj["image_sizes"]
-        # Only when pixel values is not None we have modalities
-        ret.modalities = obj["modalities"] or ["image"]
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret
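The pad_values computation above turns a single image hash into four in-vocabulary pseudo token ids, so the radix cache can prefix-match requests that reuse the same image. A standalone sketch of the same arithmetic; the vocab size is an assumed example value, not something read from this diff.

# Standalone sketch of the pad_values arithmetic in ImageInputs.from_dict above.
# vocab_size is an assumed example value; any model vocab size works the same way.
vocab_size = 128256
image_hash = hash(("example-image-hash",))
pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
# Four ids derived from different bit ranges of the hash, each reduced modulo the vocab size.
print(pad_values)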
@@ -416,6 +426,10 @@ class ScheduleBatch:
     req_to_token_pool: ReqToTokenPool = None
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -440,6 +454,12 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

@@ -450,12 +470,20 @@ class ScheduleBatch:
     device: str = "cuda"

     @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
+            model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_regex=any(req.regex_fsm for req in reqs),

@@ -493,7 +521,78 @@ class ScheduleBatch:
         return out_cache_loc

-    def prepare_for_extend(self, vocab_size: int):
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
         bs = len(self.reqs)
@@ -561,8 +660,13 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self, vocab_size, global_server_args_dict["disable_penalizer"]
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):

@@ -752,6 +856,10 @@ class ScheduleBatch:
         return jump_forward_reqs

+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

@@ -766,16 +874,22 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
         if enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs

@@ -802,6 +916,10 @@ class ScheduleBatch:
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True

@@ -828,6 +946,11 @@ class ScheduleBatch:
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )

@@ -850,14 +973,11 @@ class ScheduleBatch:
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]

@@ -887,7 +1007,11 @@ class ScheduleBatch:
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,

@@ -897,6 +1021,7 @@ class ScheduleBatch:
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,

@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]
python/sglang/srt/managers/scheduler.py  (view file @ 94cde109)

@@ -662,8 +662,9 @@ class Scheduler:
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
+            self.model_config,
         )
-        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
python/sglang/srt/managers/tokenizer_manager.py
...
@@ -122,7 +122,7 @@ class TokenizerManager:
            # We want to parallelize the image pre-processing so we create an executor for it
            self.image_processor = get_image_processor(
-               self.hf_config, server_args, self.processor.image_processor
+               self.hf_config, server_args, self.processor
            )
        else:
            self.tokenizer = get_tokenizer(
...
@@ -191,8 +191,10 @@ class TokenizerManager:
            sampling_params = self._get_sampling_params(obj.sampling_params)
            if self.is_generation:
                image_inputs = await self.image_processor.process_images_async(
-                   obj.image_data, obj
+                   obj.image_data, input_text or input_ids, obj
                )
+               if image_inputs and "input_ids" in image_inputs:
+                   input_ids = image_inputs["input_ids"]
                return_logprob = obj.return_logprob
                logprob_start_len = obj.logprob_start_len
                top_logprobs_num = obj.top_logprobs_num
...
@@ -217,8 +219,10 @@ class TokenizerManager:
            sampling_params = self._get_sampling_params(obj.sampling_params[index])
            if self.is_generation:
                image_inputs = await self.image_processor.process_images_async(
-                   obj.image_data[index], obj
+                   obj.image_data[index], input_text or input_ids, obj
                )
+               if image_inputs and "input_ids" in image_inputs:
+                   input_ids = image_inputs["input_ids"]
                return_logprob = obj.return_logprob[index]
                logprob_start_len = obj.logprob_start_len[index]
                top_logprobs_num = obj.top_logprobs_num[index]
...
@@ -263,8 +267,10 @@ class TokenizerManager:
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
-               obj.image_data[0], obj
+               obj.image_data[0], input_text or input_ids, obj
            )
+           if image_inputs and "input_ids" in image_inputs:
+               input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]
...
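The call sites above imply a small contract for the image processors: process_images_async now receives the prompt (text or token ids) in addition to the raw image data, and may hand back rewritten input_ids (for example, with image placeholder tokens inserted). Below is a minimal sketch of a processor honoring that contract; the class name and returned fields are illustrative only, not the actual Mllama processor.

class SketchImageProcessor:
    """Illustrative stand-in for an SGLang async image processor (assumed shape only)."""

    async def process_images_async(self, image_data, input_text_or_ids, request_obj):
        if not image_data:
            return None
        # A real processor would run the HF image processor here; this only shows
        # the dict shape that TokenizerManager consumes after this change.
        return {
            "pixel_values": [],            # processed image tensors would go here
            "input_ids": [101, 102, 103],  # prompt ids, possibly expanded with image tokens
        }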
python/sglang/srt/mem_cache/memory_pool.py
...
@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
import torch

+from sglang.srt.layers.radix_attention import RadixAttention
+
logger = logging.getLogger(__name__)
...
@@ -41,13 +43,17 @@ class ReqToTokenPool:
        )
        self.free_slots = list(range(size))
        self.write_records = []
+       self.use_records = use_records

-       if use_records:
+       if self.use_records:
+           # records all write operations
            self.write = self.write_with_records
        else:
            self.write = self.write_without_records

+   def write(self, indices, values):
+       # Keep the signature for type checking, will be initialized during runtime
+       raise NotImplementedError()
+
    def available_size(self):
        return len(self.free_slots)
...
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
    def set_kv_buffer(
        self,
-       layer_id: int,
+       layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
...
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
    def set_kv_buffer(
        self,
-       layer_id: int,
+       layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
+       layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if cache_v.dtype != self.dtype:
...
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
    def set_kv_buffer(
        self,
-       layer_id: int,
+       layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
+       layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
...
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
    def set_kv_buffer(
        self,
-       layer_id: int,
+       layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
+       layer_id = layer.layer_id
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
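As a usage note, the KV pools now take the RadixAttention layer itself rather than a bare layer_id and read layer.layer_id internally. A minimal caller sketch, assuming a pool and layer already exist (the helper name below is hypothetical):

import torch

def store_kv(token_to_kv_pool, layer, cache_loc: torch.Tensor,
             k: torch.Tensor, v: torch.Tensor):
    # Before this commit the caller passed layer.layer_id; now it passes the
    # RadixAttention layer and the pool extracts layer.layer_id itself.
    token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)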
python/sglang/srt/model_executor/cuda_graph_runner.py
...
@@ -105,6 +105,7 @@ class CudaGraphRunner:
        self.graph_memory_pool = None
        self.use_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+       self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder

        # Batch sizes to capture
        if self.model_runner.server_args.disable_cuda_graph_padding:
...
@@ -132,6 +133,9 @@ class CudaGraphRunner:
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

+       # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+       self.encoder_len_fill_value = 0
+
        if self.use_torch_compile:
            set_torch_compile_config()
...
@@ -144,9 +148,18 @@ class CudaGraphRunner:
            )
            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

+           if self.is_encoder_decoder:
+               # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
+               self.encoder_lens = torch.full(
+                   (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
+               )
+           else:
+               self.encoder_lens = None
+
        # Capture
        try:
-           self.capture()
+           with self.model_capture_mode():
+               self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
...
@@ -157,11 +170,32 @@ class CudaGraphRunner:
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
            )

-   def can_run(self, batch_size: int):
-       if self.disable_padding:
-           return batch_size in self.graphs
-       else:
-           return batch_size <= self.max_bs
+   @contextmanager
+   def model_capture_mode(self):
+       if hasattr(self.model_runner.model, "capture_mode"):
+           self.model_runner.model.capture_mode = True
+       yield
+       if hasattr(self.model_runner.model, "capture_mode"):
+           self.model_runner.model.capture_mode = False
+
+   def can_run(self, forward_batch: ForwardBatch):
+       is_bs_supported = (
+           forward_batch.batch_size in self.graphs
+           if self.disable_padding
+           else forward_batch.batch_size <= self.max_bs
+       )
+
+       # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+       # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+       # because the full_text_row_masked_out_mask tensor will always be ones
+       is_encoder_lens_supported = (
+           torch.all(forward_batch.encoder_lens > 0)
+           if self.is_encoder_decoder
+           else True
+       )
+       return is_bs_supported and is_encoder_lens_supported

    def capture(self):
        with graph_capture() as graph_capture_context:
...
@@ -188,11 +222,19 @@ class CudaGraphRunner:
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:bs]
+       if self.is_encoder_decoder:
+           encoder_lens = self.encoder_lens[:bs]
+       else:
+           encoder_lens = None
+
        seq_lens_sum = seq_lens.sum().item()

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-           bs, req_pool_indices, seq_lens
+           bs,
+           req_pool_indices,
+           seq_lens,
+           encoder_lens,
        )

        # Run and capture
...
@@ -208,6 +250,7 @@ class CudaGraphRunner:
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens_sum,
+           encoder_lens=encoder_lens,
            return_logprob=False,
            top_logprobs_nums=[0] * bs,
            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
...
@@ -251,6 +294,8 @@ class CudaGraphRunner:
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+       if self.is_encoder_decoder:
+           self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...
@@ -258,6 +303,7 @@ class CudaGraphRunner:
            self.req_pool_indices,
            self.seq_lens,
            forward_batch.seq_lens_sum,
+           self.encoder_lens,
        )

        # Replay
...
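To make the new eligibility rule concrete, here is a rough, self-contained illustration of the check can_run performs for an encoder-decoder model: a mixed batch (any request with encoder_len == 0) falls back to the non-graph decode path. The tensor values below are made-up examples, not taken from the code.

import torch

batch_size = 4
max_bs = 64  # stands in for self.max_bs
encoder_lens = torch.tensor([17, 0, 9, 23], dtype=torch.int32)  # one text-only request

is_bs_supported = batch_size <= max_bs
is_encoder_lens_supported = bool(torch.all(encoder_lens > 0))

print(is_bs_supported and is_encoder_lens_supported)  # False: CUDA graph replay is skipped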
python/sglang/srt/model_executor/forward_batch_info.py
...
@@ -108,6 +108,12 @@ class ForwardBatch:
    # For multimodal
    image_inputs: Optional[List[ImageInputs]] = None

+   # Encoder-decoder
+   encoder_cached: Optional[List[bool]] = None
+   encoder_lens: Optional[torch.Tensor] = None
+   encoder_lens_cpu: Optional[List[int]] = None
+   encoder_out_cache_loc: Optional[torch.Tensor] = None

    # For LoRA
    lora_paths: Optional[List[str]] = None
...
@@ -194,6 +200,11 @@ class ForwardBatch:
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            out_cache_loc=batch.out_cache_loc,
+           image_inputs=batch.image_inputs,
+           encoder_cached=batch.encoder_cached,
+           encoder_lens=batch.encoder_lens,
+           encoder_lens_cpu=batch.encoder_lens_cpu,
+           encoder_out_cache_loc=batch.encoder_out_cache_loc,
            seq_lens_sum=batch.seq_lens_sum,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
...
@@ -212,11 +223,11 @@ class ForwardBatch:
                ],
                axis=0,
            )

-           ret.image_inputs = batch.image_inputs
            ret.extend_num_tokens = batch.extend_num_tokens
            ret.extend_seq_lens = torch.tensor(
                batch.extend_seq_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
            ret.extend_prefix_lens = torch.tensor(
                batch.extend_prefix_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
...
python/sglang/srt/model_executor/model_runner.py
...
@@ -270,7 +270,6 @@ class ModelRunner:
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
-       self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )
...
@@ -510,7 +509,7 @@ class ModelRunner:
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
-           assert not self.has_cross_attention, (
+           assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
...
@@ -558,9 +557,7 @@ class ModelRunner:
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
-       if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-           forward_batch.batch_size
-       ):
+       if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
...
python/sglang/srt/models/mllama.py (new file, 0 → 100644)
This diff is collapsed.
python/sglang/srt/models/qwen2_vl.py
...
@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
            ]
            positions = forward_batch.mrope_positions

-       if image_inputs is None or len(image_inputs) == 0:
+       if (
+           forward_batch.forward_mode.is_decode()
+           or image_inputs is None
+           or len(image_inputs) == 0
+       ):
            inputs_embeds = self.model.embed_tokens(input_ids)
        else:
            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
...
python/sglang/srt/utils.py
...
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
        or "LlavaQwenForCausalLM" in model_architectures
        or "LlavaMistralForCausalLM" in model_architectures
        or "LlavaVidForCausalLM" in model_architectures
+       or "MllamaForConditionalGeneration" in model_architectures
        or "Qwen2VLForConditionalGeneration" in model_architectures
    ):
        return True
...