zhaoyu6 / sglang · Commit 94cde109 (unverified)

Authored Oct 21, 2024 by Liangsheng Yin, committed via GitHub on Oct 21, 2024.

Llama3.2 vision model support (#1551)

Parent: 00611286

Showing 20 changed files with 1536 additions and 118 deletions (+1536 / -118).
Changed files:

python/pyproject.toml                                            +28    -9
python/sglang/bench_latency.py                                   +2     -1
python/sglang/lang/chat_template.py                              +1     -0
python/sglang/srt/configs/model_config.py                        +2     -0
python/sglang/srt/conversation.py                                +13    -0
python/sglang/srt/layers/attention/__init__.py                   +11    -4
python/sglang/srt/layers/attention/double_sparsity_backend.py    +17    -5
python/sglang/srt/layers/attention/flashinfer_backend.py         +121   -36
python/sglang/srt/layers/attention/triton_backend.py             +17    -5
python/sglang/srt/managers/image_processor.py                    +72    -16
python/sglang/srt/managers/schedule_batch.py                     +148   -17
python/sglang/srt/managers/scheduler.py                          +2     -1
python/sglang/srt/managers/tokenizer_manager.py                  +10    -4
python/sglang/srt/mem_cache/memory_pool.py                       +15    -6
python/sglang/srt/model_executor/cuda_graph_runner.py            +53    -7
python/sglang/srt/model_executor/forward_batch_info.py           +12    -1
python/sglang/srt/model_executor/model_runner.py                 +2     -5
python/sglang/srt/models/mllama.py                               +1004  -0
python/sglang/srt/models/qwen2_vl.py                             +5     -1
python/sglang/srt/utils.py                                       +1     -0
python/pyproject.toml

@@ -8,16 +8,12 @@ version = "0.3.4"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { file = "LICENSE" }
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: Apache Software License",
 ]
-dependencies = [
-    "requests",
-    "tqdm",
-    "numpy",
-]
+dependencies = ["requests", "tqdm", "numpy"]

 [project.optional-dependencies]
 runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", ...

@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
+test = [
+    "jsonlines",
+    "matplotlib",
+    "pandas",
+    "sentence_transformers",
+    "accelerate",
+    "peft",
+]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]

@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"

 [tool.setuptools.packages.find]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]

 [tool.wheel]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = [
+    "assets*",
+    "benchmark*",
+    "docs*",
+    "dist*",
+    "playground*",
+    "scripts*",
+    "tests*",
+]
python/sglang/bench_latency.py

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(model_runner.model_config.vocab_size)
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
python/sglang/lang/chat_template.py

@@ -229,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )
python/sglang/srt/configs/model_config.py

@@ -89,6 +89,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
python/sglang/srt/conversation.py

@@ -509,6 +509,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",
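For reference, a minimal sketch of the prompt shape that the "llama_3_vision" template registered above implies. The actual string assembly is done by Conversation with SeparatorStyle.LLAMA3 (not shown in this diff); the hand-built string below only illustrates the registered pieces — the system template, the roles, the <|image|> token, and the stop strings — and assumes the standard Llama 3 header layout for the user/assistant turns.

# Illustrative only; inside sglang the Conversation object renders this string.
system_message = (
    "You are a helpful language and vision assistant. You are able to understand "
    "the visual content that the user provides, and assist the user with a variety "
    "of tasks using natural language."
)
prompt = (
    f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "<|image|>What is in this picture?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
stop_strs = ["<|end_of_text|>", "<|eot_id|>"]  # generation stops at either token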
python/sglang/srt/layers/attention/__init__.py

 from abc import ABC, abstractmethod
+from typing import Optional

 import torch
-from torch import nn

+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run forward on an attention layer."""

@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for decode."""

@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        layer: nn.Module,
+        layer: RadixAttention,
         forward_batch: ForwardBatch,
     ):
         """Run a forward for extend."""
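To make the revised interface concrete, here is a minimal skeleton written against the abstract class above: the two cuda-graph hooks now accept an optional encoder_lens tensor, and the per-layer entry points take the RadixAttention layer itself rather than a bare nn.Module. The class name and bodies below are placeholders, not part of this commit, and other hooks of the base class are omitted.

from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class SketchAttnBackend(AttentionBackend):
    # Hypothetical skeleton; a real backend also implements the remaining hooks.

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        pass  # a real backend builds its capture-time metadata here

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        pass  # a real backend refreshes the captured buffers here

    def forward_extend(
        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
    ):
        return q  # placeholder: a real backend attends over the KV cache

    def forward_decode(
        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
    ):
        return q  # placeholder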
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         (

@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+            layer, forward_batch.out_cache_loc, k, v, k_label
         )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -11,7 +11,6 @@ from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn
 import triton
 import triton.language as tl

@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

 if is_flashinfer_available():

@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
         assert not (
             model_runner.sliding_window_size is not None
-            and model_runner.has_cross_attention
+            and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
-        elif model_runner.has_cross_attention:
+        elif model_runner.model_config.is_encoder_decoder:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
         else:

@@ -128,6 +128,8 @@
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
+                decode_wrappers=None,
+                encoder_lens=forward_batch.encoder_lens,
             )
             self.forward_metadata = (self.decode_wrappers,)
         else:

@@ -144,13 +146,11 @@
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged,
+                use_ragged=use_ragged,
+                encoder_lens=forward_batch.encoder_lens,
             )

-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-            )
+            self.forward_metadata = (use_ragged, extend_no_prefix)

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(

@@ -163,7 +163,11 @@
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: torch.Tensor = None,
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):

@@ -181,7 +185,11 @@
         seq_lens_sum = seq_lens.sum().item()
         self.indices_updater_decode.update(
-            req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            decode_wrappers=decode_wrappers,
+            encoder_lens=encoder_lens,
         )
         self.cuda_graph_metadata[bs] = decode_wrappers
         self.forward_metadata = (decode_wrappers,)

@@ -192,34 +200,42 @@
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: torch.Tensor = None,
     ):
         self.indices_updater_decode.update(
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            self.cuda_graph_metadata[bs],
+            decode_wrappers=self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
         ]

         use_ragged, extend_no_prefix = self.forward_metadata
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, forward_batch.out_cache_loc, k, v
-                )
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=True,
+                causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=layer.logit_cap,

@@ -247,20 +263,23 @@
             o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -271,7 +290,7 @@
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def _get_wrapper_idx(self, layer: nn.Module):
+    def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0

@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -305,20 +326,27 @@
         self.decode_wrappers = attn_backend.decode_wrappers

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(
+        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+    ):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(

@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers=None,
+        encoder_lens=None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -363,8 +392,35 @@
                 kv_start_idx_tmp,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self,
+        req_pool_indices,
+        seq_lens,
+        seq_lens_sum,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # Cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+                seq_lens_sum = encoder_lens.sum().item()
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens_sum,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )

     def call_begin_forward(
         self,

@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len

@@ -430,16 +488,20 @@
         self.wrappers_paged = attn_backend.prefill_wrappers_paged

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
         )

     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -487,8 +549,31 @@
                 use_ragged,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
+    def update_cross_attention(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )

     def call_begin_forward(
         self,
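A small, self-contained illustration of the wrapper split that update_cross_attention implements above. For an encoder-decoder request the req-to-token pool holds the encoder (image) tokens first and the decoder tokens after them, so the self-attention wrapper starts at offset encoder_len and covers seq_len entries (kv_start_idx = encoder_lens, paged_kernel_lens = seq_lens), while the cross-attention wrapper starts at offset 0 and covers encoder_len entries. The numbers below are made up.

def kv_ranges(encoder_len: int, seq_len: int):
    # wrapper 0: normal (self) attention over the decoder tokens
    self_attn = (encoder_len, encoder_len + seq_len)
    # wrapper 1: cross attention over the encoder tokens
    cross_attn = (0, encoder_len)
    return self_attn, cross_attn


print(kv_ranges(encoder_len=100, seq_len=32))
# -> ((100, 132), (0, 100))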
python/sglang/srt/layers/attention/triton_backend.py

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,

@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))

@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
             o = torch.empty_like(q)

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, forward_batch.out_cache_loc, k, v
+            layer, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(
python/sglang/srt/managers/image_processor.py

@@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs):
 class BaseImageProcessor(ABC):
+    def __init__(self, hf_config, server_args, _processor):
+        self.hf_config = hf_config
+        self._processor = _processor
+
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+        )
+
     @abstractmethod
-    async def process_images_async(self, image_data, **kwargs):
+    async def process_images_async(self, image_data, input_text, **kwargs):
         pass


 class DummyImageProcessor(BaseImageProcessor):
+    def __init__(self):
+        pass
+
     async def process_images_async(self, *args, **kwargs):
         return None


 class LlavaImageProcessor(BaseImageProcessor):
-    def __init__(self, hf_config, server_args, _image_processor):
-        self.hf_config = hf_config
-        self._image_processor = _image_processor
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(server_args,),
-            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
-        )
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)

     @staticmethod
     def _process_single_image_task(

@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         )

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
         }


+class MllamaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(images, input_text):
+        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
+        return global_processor(images, input_text, return_tensors="pt")
+
+    async def _process_single_image(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MllamaImageProcessor._process_single_image_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(images, input_text, return_tensors="pt")
+
+        return image_inputs
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if len(image_data) > 0:
+            images = [load_image(image)[0] for image in image_data]
+        else:
+            images = load_image(image_data[0])[0]
+
+        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs["image_hashes"] = [hash(str(image_data))]
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        return image_inputs
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config

@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             return self._process_single_image_task(image_data)

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], request_obj
+        self, image_data: List[Union[str, bytes]], input_text, request_obj
     ):
         if not image_data:
             return None

@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
 def get_image_processor(
-    hf_config, server_args: ServerArgs, _image_processor
+    hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
-    if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
+    if "MllamaForConditionalGeneration" in hf_config.architectures:
+        return MllamaImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
     else:
-        return LlavaImageProcessor(hf_config, server_args, _image_processor)
+        return LlavaImageProcessor(hf_config, server_args, processor.image_processor)


 def get_dummy_image_processor():
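A hedged sketch of exercising the new MllamaImageProcessor outside the server. Inside sglang the TokenizerManager constructs this object (see tokenizer_manager.py below); the checkpoint id, the minimal ServerArgs, and the standalone asyncio driver here are assumptions made only for illustration.

import asyncio

from transformers import AutoConfig, AutoProcessor

from sglang.srt.managers.image_processor import MllamaImageProcessor
from sglang.srt.server_args import ServerArgs

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint
hf_config = AutoConfig.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
server_args = ServerArgs(model_path=model_id)

image_proc = MllamaImageProcessor(hf_config, server_args, processor)


async def main():
    out = await image_proc.process_images_async(
        image_data=["https://example.com/cat.jpg"],    # URL or raw bytes, per load_image
        input_text="<|image|>Describe this picture.",  # may also be a list of token ids
    )
    # Returns the HF processor outputs plus "image_hashes" and a flattened "input_ids".
    print(out.keys())


asyncio.run(main())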
python/sglang/srt/managers/schedule_batch.py

@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache

@@ -121,11 +122,12 @@ class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-    image_hash: int
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None

@@ -138,19 +140,27 @@ class ImageInputs:
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hash=hash(tuple(obj["image_hashes"])),
-            image_grid_thws=obj.get("image_grid_thws"),
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.image_hash
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-        ret.image_sizes = obj["image_sizes"]
-        # Only when pixel values is not None we have modalities
-        ret.modalities = obj["modalities"] or ["image"]
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret

@@ -416,6 +426,10 @@ class ScheduleBatch:
     req_to_token_pool: ReqToTokenPool = None
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -440,6 +454,12 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

@@ -450,12 +470,20 @@ class ScheduleBatch:
     device: str = "cuda"

     @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
+            model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_regex=any(req.regex_fsm for req in reqs),

@@ -493,7 +521,78 @@ class ScheduleBatch:
         return out_cache_loc

-    def prepare_for_extend(self, vocab_size: int):
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)

@@ -561,8 +660,13 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self, vocab_size, global_server_args_dict["disable_penalizer"]
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):

@@ -752,6 +856,10 @@ class ScheduleBatch:
         return jump_forward_reqs

+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

@@ -766,16 +874,22 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
         if enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
             self.req_to_token_pool.write(
-                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+                (self.req_pool_indices, locs), self.out_cache_loc
             )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs

@@ -802,6 +916,10 @@ class ScheduleBatch:
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True

@@ -828,6 +946,11 @@ class ScheduleBatch:
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )

@@ -850,14 +973,11 @@ class ScheduleBatch:
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]

@@ -887,7 +1007,11 @@ class ScheduleBatch:
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,

@@ -897,6 +1021,7 @@ class ScheduleBatch:
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,

@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]
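A small worked example of the pad_values computation above: the combined image hash is folded into four pseudo token ids (modulo the vocabulary size) that stand in for the image region during prefix matching, exactly as the comment in from_dict notes. The concrete numbers here are made up for illustration.

vocab_size = 128_256                 # assumed vocab size, e.g. the Llama 3 family
image_hash = hash(("deadbeef",))     # in the server: hash(tuple(obj["image_hashes"]))

pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_values)  # four ints in [0, vocab_size)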
python/sglang/srt/managers/scheduler.py

@@ -662,8 +662,9 @@ class Scheduler:
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
+            self.model_config,
         )
-        new_batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
python/sglang/srt/managers/tokenizer_manager.py

@@ -122,7 +122,7 @@ class TokenizerManager:
             # We want to parallelize the image pre-processing so we create an executor for it
             self.image_processor = get_image_processor(
-                self.hf_config, server_args, self.processor.image_processor
+                self.hf_config, server_args, self.processor
             )
         else:
             self.tokenizer = get_tokenizer(

@@ -191,8 +191,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params)
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data, obj
+                    obj.image_data, input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob
                 logprob_start_len = obj.logprob_start_len
                 top_logprobs_num = obj.top_logprobs_num

@@ -217,8 +219,10 @@ class TokenizerManager:
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 if self.is_generation:
                     image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], obj
+                        obj.image_data[index], input_text or input_ids, obj
                     )
+                    if image_inputs and "input_ids" in image_inputs:
+                        input_ids = image_inputs["input_ids"]
                     return_logprob = obj.return_logprob[index]
                     logprob_start_len = obj.logprob_start_len[index]
                     top_logprobs_num = obj.top_logprobs_num[index]

@@ -263,8 +267,10 @@ class TokenizerManager:
                 sampling_params = SamplingParams(**obj.sampling_params[0])
                 sampling_params.max_new_tokens = 0
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[0], obj
+                    obj.image_data[0], input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob[0]
                 logprob_start_len = obj.logprob_start_len[0]
                 top_logprobs_num = obj.top_logprobs_num[0]
python/sglang/srt/mem_cache/memory_pool.py

@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
 import torch

+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)

@@ -41,13 +43,17 @@ class ReqToTokenPool:
         )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records

-        if use_records:
-            # records all write operations
+        if self.use_records:
             self.write = self.write_with_records
         else:
             self.write = self.write_without_records

+    def write(self, indices, values):
+        # Keep the signature for type checking, will be initialized during runtime
+        raise NotImplementedError()
+
     def available_size(self):
         return len(self.free_slots)

@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,

@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if cache_v.dtype != self.dtype:

@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
        loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:

@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         cache_label: torch.Tensor,
     ):
         # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
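A toy, self-contained sketch of the convention introduced above: the KV pool now takes the attention layer object and pulls layer_id (and any per-layer flags) from it, instead of every call site unpacking layer.layer_id. The classes below are stand-ins, not the real MHATokenToKVPool or RadixAttention.

import torch


class ToyLayer:
    def __init__(self, layer_id: int):
        self.layer_id = layer_id


class ToyKVPool:
    def __init__(self, num_layers: int, size: int, head_dim: int):
        self.k_buffer = [torch.zeros(size, head_dim) for _ in range(num_layers)]
        self.v_buffer = [torch.zeros(size, head_dim) for _ in range(num_layers)]

    def set_kv_buffer(self, layer: ToyLayer, loc: torch.Tensor, cache_k, cache_v):
        layer_id = layer.layer_id  # was an explicit argument before this commit
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v


pool = ToyKVPool(num_layers=2, size=16, head_dim=4)
loc = torch.tensor([3, 4, 5])
pool.set_kv_buffer(ToyLayer(layer_id=1), loc, torch.ones(3, 4), torch.ones(3, 4))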
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -105,6 +105,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = None
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder

         # Batch sizes to capture
         if self.model_runner.server_args.disable_cuda_graph_padding:

@@ -132,6 +133,9 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

+        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.encoder_len_fill_value = 0
+
         if self.use_torch_compile:
             set_torch_compile_config()

@@ -144,9 +148,18 @@
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

+        if self.is_encoder_decoder:
+            # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
+            self.encoder_lens = torch.full(
+                (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
+            )
+        else:
+            self.encoder_lens = None
+
         # Capture
         try:
-            self.capture()
+            with self.model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"

@@ -157,11 +170,32 @@
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
             )

-    def can_run(self, batch_size: int):
-        if self.disable_padding:
-            return batch_size in self.graphs
-        else:
-            return batch_size <= self.max_bs
+    @contextmanager
+    def model_capture_mode(self):
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = True
+
+        yield
+
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = False
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+
+        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+        # because the full_text_row_masked_out_mask tensor will always be ones
+        is_encoder_lens_supported = (
+            torch.all(forward_batch.encoder_lens > 0)
+            if self.is_encoder_decoder
+            else True
+        )
+
+        return is_bs_supported and is_encoder_lens_supported

     def capture(self):
         with graph_capture() as graph_capture_context:

@@ -188,11 +222,19 @@
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+
         seq_lens_sum = seq_lens.sum().item()

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-            bs, req_pool_indices, seq_lens
+            bs,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
         )

         # Run and capture

@@ -208,6 +250,7 @@
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
             seq_lens_sum=seq_lens_sum,
+            encoder_lens=encoder_lens,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
             positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),

@@ -251,6 +294,8 @@
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        if self.is_encoder_decoder:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(

@@ -258,6 +303,7 @@
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum,
+            self.encoder_lens,
         )

         # Replay
python/sglang/srt/model_executor/forward_batch_info.py

@@ -108,6 +108,12 @@ class ForwardBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None

+    # Encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # For LoRA
     lora_paths: Optional[List[str]] = None

@@ -194,6 +200,11 @@
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
+            image_inputs=batch.image_inputs,
+            encoder_cached=batch.encoder_cached,
+            encoder_lens=batch.encoder_lens,
+            encoder_lens_cpu=batch.encoder_lens_cpu,
+            encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,

@@ -212,11 +223,11 @@
                 ],
                 axis=0,
             )

-        ret.image_inputs = batch.image_inputs
         ret.extend_num_tokens = batch.extend_num_tokens
         ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, dtype=torch.int32).to(
             device, non_blocking=True
         )
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
python/sglang/srt/model_executor/model_runner.py

@@ -270,7 +270,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )

@@ -510,7 +509,7 @@
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
-            assert not self.has_cross_attention, (
+            assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )

@@ -558,9 +557,7 @@
             self.cuda_graph_runner = CudaGraphRunner(self)

     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-            forward_batch.batch_size
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)

         forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
python/sglang/srt/models/mllama.py (new file, 0 → 100644)

# Adapted from:
# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
"""PyTorch Mllama model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
import vllm.distributed.parallel_state as ps
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
    _prepare_aspect_ratio_attention_mask,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP


class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.
    Column parallel over unfolded input.
    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride (default 1): Stride for convolution.
        bias (default False): Use bias in Conv2d.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * kernel_size[0] * kernel_size[1],
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._unfold(x)
        x = x.permute(0, 2, 1)
        x, _ = self._linear(x)
        return x
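A quick shape check of the patching scheme documented in the docstring above, using plain nn.Unfold and nn.Linear as stand-ins for the column-parallel linear (assumption: no tensor parallelism; the image size, patch size, and output width below are illustrative numbers, not the Llama 3.2 configuration).

import torch
from torch import nn

bsz, in_channels, size, patch = 1, 3, 448, 14
unfold = nn.Unfold(kernel_size=patch, stride=patch)
linear = nn.Linear(in_channels * patch * patch, 1280, bias=False)  # 1280: assumed out_channels

x = torch.randn(bsz, in_channels, size, size)
x = unfold(x)           # (bsz, in_channels * patch * patch, num_tokens)
x = x.permute(0, 2, 1)  # (bsz, num_tokens, in_channels * patch * patch)
x = linear(x)           # (bsz, num_tokens, out_channels)
print(x.shape)          # torch.Size([1, 1024, 1280]) since (448 // 14) ** 2 == 1024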
class
MllamaPrecomputedAspectRatioEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
,
is_gated
:
bool
=
True
):
super
().
__init__
()
self
.
max_num_tiles
=
config
.
max_num_tiles
self
.
hidden_size
=
config
.
hidden_size
self
.
max_aspect_ratio_id
=
config
.
max_aspect_ratio_id
self
.
is_gated
=
is_gated
self
.
embedding
=
nn
.
Embedding
(
self
.
max_aspect_ratio_id
+
1
,
self
.
max_num_tiles
*
self
.
hidden_size
)
if
is_gated
:
self
.
gate
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
aspect_ratio_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
embeddings
=
self
.
embedding
(
aspect_ratio_ids
)
embeddings
=
embeddings
.
reshape
(
-
1
,
self
.
max_num_tiles
,
1
,
self
.
hidden_size
)
if
self
.
is_gated
:
embeddings
=
embeddings
*
self
.
gate
.
tanh
()
hidden_state
=
hidden_state
+
embeddings
return
hidden_state
class
MllamaPrecomputedPositionEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
):
super
().
__init__
()
self
.
max_num_tiles
=
config
.
max_num_tiles
self
.
max_aspect_ratio_id
=
config
.
max_aspect_ratio_id
self
.
num_patches
=
(
config
.
image_size
//
config
.
patch_size
)
**
2
+
1
self
.
hidden_size
=
config
.
hidden_size
self
.
scale
=
config
.
hidden_size
**-
0.5
self
.
gate
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
# position embedding
position_embedding
=
torch
.
randn
(
self
.
num_patches
,
self
.
hidden_size
)
self
.
embedding
=
nn
.
Parameter
(
self
.
scale
*
position_embedding
)
# tile position embedding
self
.
tile_embedding
=
nn
.
Embedding
(
self
.
max_aspect_ratio_id
+
1
,
self
.
max_num_tiles
*
self
.
num_patches
*
self
.
hidden_size
,
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
aspect_ratio_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# position embeddings
gated_position_embedding
=
(
1
-
self
.
gate
.
tanh
())
*
self
.
embedding
hidden_state
=
hidden_state
+
gated_position_embedding
.
view
(
1
,
1
,
self
.
num_patches
,
self
.
hidden_size
)
# precomputed tile position embeddings
tile_position_embedding
=
self
.
tile_embedding
(
aspect_ratio_ids
)
batch_size
=
hidden_state
.
shape
[
0
]
tile_position_embedding
=
tile_position_embedding
.
reshape
(
batch_size
,
self
.
max_num_tiles
,
self
.
num_patches
,
self
.
hidden_size
)
gated_tile_position_embedding
=
self
.
gate
.
tanh
()
*
tile_position_embedding
hidden_state
=
hidden_state
+
gated_tile_position_embedding
return
hidden_state
class MllamaVisionSdpaAttention(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        model_parallel_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // model_parallel_size
        self.q_size = self.num_local_heads * self.head_dim
        self.kv_size = self.num_local_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(
            q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
        ).transpose(1, 2)
        k = k.view(
            k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
        ).transpose(1, 2)
        v = v.view(
            v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
        ).transpose(1, 2)

        # TODO: remove padding in image encoder
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(
            attn_output.shape[0], attn_output.shape[1], -1
        )
        output, _ = self.o_proj(attn_output)
        return output
class MllamaVisionMLP(nn.Module):
    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
class MllamaVisionEncoderLayer(nn.Module):
    def __init__(
        self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = MllamaVisionMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            self.hidden_size, eps=config.norm_eps
        )

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state
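For the gated variant used by the global encoder, the attention and FFN gates start at pi/4, so the gated branches begin with a weight of roughly 0.66 rather than 0. A quick standalone check (illustrative only, not part of the file):

import math

print(math.tanh(math.pi / 4))  # ~0.6558, the initial weight on the gated branches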
class MllamaVisionEncoder(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        num_layers=32,
        is_gated=False,
        output_hidden_states=None,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]
        )
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return hidden_states, encoder_states
class MllamaVisionModel(nn.Module):
    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
        )
        self.global_transformer = MllamaVisionEncoder(
            config, config.num_global_layers, is_gated=True
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        aspect_ratio_mask: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype)
        )
        hidden_state = patch_embeds
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1
        )
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state
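The forward pass above pads the patch dimension up to a multiple of 8 before attention and slices the padding off again afterwards. A small standalone sketch of that arithmetic (illustrative numbers only, not part of the file):

def pad_to_multiple_of_8(num_tokens: int) -> int:
    # same formula as num_padding_patches above
    return (8 - num_tokens % 8) % 8

for n in (1601, 1608):
    pad = pad_to_multiple_of_8(n)
    print(n, pad, n + pad)  # 1601 -> 7 -> 1608; 1608 -> 0 -> 1608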
class MllamaTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_id: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = (
            self.num_key_value_heads // self.model_parallel_size
        )
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_id = layer_id
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = RadixAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            layer_id=layer_id,
            is_cross_attention=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv_dec, _ = self.qkv_proj(hidden_states)
        q, _, _ = qkv_dec.split(
            [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
        )
        if cross_attention_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_enc.split(
                [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
            )
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        output = self.attn(q, k, v, forward_batch)
        out, _ = self.o_proj(output)
        return out
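Because the text and vision sequences have different lengths, the fused QKV projection above is applied twice: once to the text tokens (keeping only q) and once to the vision tokens (keeping only k and v). A standalone shape sketch with made-up sizes (not part of the file):

import torch

num_heads, num_kv_heads, head_dim = 4, 2, 8
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

text_qkv = torch.randn(10, q_size + 2 * kv_size)    # 10 text tokens
vision_qkv = torch.randn(25, q_size + 2 * kv_size)  # 25 vision tokens

q, _, _ = text_qkv.split([q_size, kv_size, kv_size], dim=-1)
_, k, v = vision_qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)  # (10, 32) (25, 16) (25, 16)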
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward."""

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig],
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            forward_batch=forward_batch,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
        return hidden_states
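Both gates in this decoder layer start at zero, so at initialization the cross-attention branch and its MLP contribute nothing and the layer reduces to an identity on the residual stream. A one-line standalone check (illustrative, not part of the file):

import torch

gate = torch.zeros(1)
residual, branch = torch.randn(4), torch.randn(4)
assert torch.equal(residual + gate.tanh() * branch, residual)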
class MllamaTextModel(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        cache_config=None,
    ):
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size + 8, config.hidden_size
        )
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_id in range(config.num_hidden_layers):
            if layer_id in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config, layer_id, quant_config=quant_config
                    )
                )
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config, quant_config=quant_config, layer_id=layer_id
                    )
                )

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for _, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                        forward_batch=forward_batch,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")
        hidden_states = self.norm(hidden_states)
        return hidden_states
class MllamaForCausalLM(nn.Module):
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        cache_config=None,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states
class MllamaForConditionalGeneration(nn.Module):
    def __init__(
        self,
        config: config_mllama.MllamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ):
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
            config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.text_config)
        self.capture_mode = False

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        pixel_values = image_inputs.pixel_values
        pad_values = image_inputs.pad_values

        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
        num_patches = self.vision_model.num_patches
        image_len = num_concurrent_media * num_tiles * num_patches
        image_inputs.num_image_tokens = image_len

        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

        return pad_ids[:image_len] + input_ids

    def _batch_image_inputs(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
            return None, None, None, None

        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
        max_num_images = max_num_tiles = bs = 0
        for i, im in enumerate(forward_batch.image_inputs):
            if not forward_batch.encoder_cached[i] and im is not None:
                max_num_images = max(max_num_images, im.pixel_values.shape[1])
                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
                bs += 1

        if max_num_images * max_num_tiles * bs == 0:
            return None, None, None, None

        with forward_batch.out_cache_loc.device:
            batched_images = torch.zeros(
                bs,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
            )
            batched_ar_ids = torch.ones(
                bs, max_num_images, dtype=torch.int64, device="cuda"
            )
            batched_ar_mask = torch.zeros(
                bs, max_num_images, max_num_tiles, dtype=torch.int64
            )
            i = 0
            encoder_lens_need = []

            for k, im in enumerate(forward_batch.image_inputs):
                if forward_batch.encoder_cached[k] or im is None:
                    continue

                encoder_lens_need.append(forward_batch.encoder_lens[k])
                for j in range(im.pixel_values.shape[1]):
                    img = im.pixel_values[0, j]
                    num_tiles = img.shape[0]
                    batched_images[i, j, :num_tiles] = img
                    batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
                    batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
                i += 1

        return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need

    def flat_encoder_result(
        self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
    ):
        # NOTE: not all encoders need computation, some are cached
        head_dim = cross_attention_states.shape[-1]
        total_encoder_len = sum(encoder_lens_need)
        cross_attention_states_flat = torch.zeros(
            total_encoder_len,
            head_dim,
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype,
        )

        i = start_pos = 0
        for encoder_len in encoder_lens_need:
            if encoder_len == 0:
                continue
            end_pos = start_pos + encoder_len
            cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
                :encoder_len
            ]
            i += 1
            start_pos += encoder_len

        return cross_attention_states_flat

    def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
        else:
            full_text_row_masked_out_mask = torch.ones(
                forward_batch.extend_seq_lens.sum(), dtype=torch.bool
            )
            start_pos = 0
            for seq_len, encoder_len in zip(
                forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu
            ):
                if encoder_len == 0:
                    full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = (
                        False
                    )
                start_pos += encoder_len
            full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
                forward_batch.seq_lens.device
            )

        return full_text_row_masked_out_mask.reshape(-1, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
            self._batch_image_inputs(forward_batch)
        )

        # TODO: support multi-image by this mask
        cross_attention_mask = None
        cross_attention_states = None

        if self.capture_mode:
            # NOTE: when doing cuda graph capture, we do not want to skip cross attention.
            # Make it a constant value to avoid cuda graph capture issues.
            skip_cross_attention = False
        else:
            # NOTE: we do not need image_inputs when prefill
            assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
            assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
            skip_cross_attention = forward_batch.encoder_lens.max() == 0

        if not skip_cross_attention:
            full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
                forward_batch
            )
        else:
            full_text_row_masked_out_mask = None

        if batched_images is not None:
            # NOTE: llama's reference implementation runs vision model on CPU
            cross_attention_states = self.vision_model(
                batched_images, batched_ar_ids, batched_ar_mask
            )
            cross_attention_states = self.multi_modal_projector(cross_attention_states)

            bs, _, _, _, image_token_dim = cross_attention_states.shape
            cross_attention_states = cross_attention_states.view(
                bs, -1, image_token_dim
            )

            cross_attention_states = self.flat_encoder_result(
                cross_attention_states, encoder_lens_need
            )

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params = set()
        for name, loaded_weight in weights:
            if "patch_embedding.weight" in name:
                name = name.replace(
                    "patch_embedding.weight", "patch_embedding._linear.weight"
                )
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = MllamaForConditionalGeneration
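For reference, a standalone sketch of the pad_input_ids arithmetic above (all numbers and the pad id are made up for illustration): one image with 4 tiles and 1601 patches per tile reserves 4 * 1601 slots in front of the text tokens, filled with the image pad id.

num_concurrent_media, num_tiles, num_patches = 1, 4, 1601
pad_values = [128256]       # assumed single image-pad token id
input_ids = [9906, 1917]    # stand-in text token ids

image_len = num_concurrent_media * num_tiles * num_patches
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
padded = pad_ids[:image_len] + input_ids
print(image_len, len(padded))  # 6404 6406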
python/sglang/srt/models/qwen2_vl.py View file @ 94cde109
...
@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
         ]
         positions = forward_batch.mrope_positions

-        if image_inputs is None or len(image_inputs) == 0:
+        if (
+            forward_batch.forward_mode.is_decode()
+            or image_inputs is None
+            or len(image_inputs) == 0
+        ):
             inputs_embeds = self.model.embed_tokens(input_ids)
         else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
...
python/sglang/srt/utils.py View file @ 94cde109
...
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
         or "LlavaQwenForCausalLM" in model_architectures
         or "LlavaMistralForCausalLM" in model_architectures
         or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
     ):
         return True
...