sglang · Commits · ea93079b

Commit ea93079b (unverified), authored Aug 02, 2025 by Wenchen Lo, committed by GitHub on Aug 02, 2025

model: adapt mllama4 to VisionAttention (#8512)

Co-authored-by: root <mickjagger19@icloud.com>

Parent: 4bec99ec
Showing 6 changed files with 518 additions and 52 deletions (+518, -52).
python/sglang/srt/hf_transformers_utils.py                   +25  -10
python/sglang/srt/layers/attention/vision.py                 +27  -10
python/sglang/srt/managers/tokenizer_manager.py              +23  -8
python/sglang/srt/models/llama4.py                           +11  -2
python/sglang/srt/models/mllama4.py                          +428 -19
python/sglang/srt/multimodal/processors/base_processor.py    +4   -3
python/sglang/srt/hf_transformers_utils.py
@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e

     tokenizer = get_tokenizer_from_processor(processor)

     attach_additional_stop_token_ids(tokenizer)
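The fallback above targets processors that ship only a fast implementation, where requesting use_fast=False raises a ValueError instead of silently degrading. A minimal usage sketch, assuming a fast-only checkpoint (the model id below is only an example, not taken from this commit):

from sglang.srt.hf_transformers_utils import get_processor

# Before this change, asking for the slow processor of a fast-only checkpoint
# raised ValueError("... does not have a slow version ..."). With the fallback
# above, get_processor logs an info message, flips use_fast to True, and
# retries, so the call below succeeds either way.
processor = get_processor(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # example model id (assumption)
    trust_remote_code=True,
    use_fast=False,
)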
python/sglang/srt/layers/attention/vision.py
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )

+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
             print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
         ]

         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
+                q = q.view(original_shape)
+                k = k.view(original_shape)
             q = q.view(original_shape)
             k = k.view(original_shape)

         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
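The new customized_position_embedding_applier hook lets a model replace the default apply_rotary_pos_emb path with its own positional scheme; forward calls it with (q, k, position_embeddings, x_shape) whenever position embeddings are supplied. A minimal sketch of a conforming callable follows; the body is a generic rotary placeholder, not the mllama4 implementation:

from typing import Any, Tuple

import torch


def custom_pos_emb_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,
    x_shape: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Placeholder matching the Callable[[Tensor, Tensor, Any, Any],
    # Tuple[Tensor, Tensor]] signature added above; a real model would apply
    # its own (e.g. 2D) rotary embedding here. Assumes cos/sin broadcast
    # against q and k along the last dimension.
    cos, sin = position_embeddings

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k


# A vision model would hand this callable to the attention layer at
# construction time, e.g.
# VisionAttention(..., customized_position_embedding_applier=custom_pos_emb_applier).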
python/sglang/srt/managers/tokenizer_manager.py
@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
     BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -202,13 +201,29 @@ class TokenizerManager:
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e

             transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
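Taken together with the use_fast=not server_args.disable_fast_image_processor default above, the retry gives a two-stage preference: honor the server argument first, then force the fast processor only when no slow version exists. A small sketch of that decision logic as a standalone illustration (not sglang code):

def pick_use_fast(disable_fast_image_processor: bool, slow_version_exists: bool) -> bool:
    # Default: prefer the fast image processor unless the user disabled it.
    if not disable_fast_image_processor:
        return True
    # Forced fallback: the checkpoint only ships a fast processor, so the
    # ValueError branch above retries with use_fast=True and logs an info line.
    if not slow_version_exists:
        return True
    return False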
python/sglang/srt/models/llama4.py
@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
             if self.use_qk_norm
             else None
         )

+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
             prefix=add_prefix("qkv_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
             prefix=add_prefix("o_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
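The quantization handling above can be read in isolation: when the quantization config carries an ignore list, any projection whose prefixed module name appears in it falls back to an unquantized linear layer. A self-contained sketch of that selection follows; the config class is a stand-in, not sglang's real quantization config, and add_prefix mirrors the helper used above:

from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class DummyQuantConfig:
    # Stand-in for a quantization config that lists module names to skip.
    ignore: List[str] = field(default_factory=list)


def add_prefix(name: str, prefix: str) -> str:
    # Mirrors the prefix helper: join a module name onto its parent prefix.
    return name if not prefix else f"{prefix}.{name}"


def select_quant_configs(
    quant_config: Optional[DummyQuantConfig], prefix: str
) -> Tuple[Optional[DummyQuantConfig], Optional[DummyQuantConfig]]:
    # Returns (qkv_quant_config, o_quant_config) following the logic above.
    qkv_quant_config = quant_config
    o_quant_config = quant_config
    if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
        if add_prefix("q_proj", prefix) in quant_config.ignore:
            qkv_quant_config = None
        if add_prefix("o_proj", prefix) in quant_config.ignore:
            o_quant_config = None
    return qkv_quant_config, o_quant_config


# Example: layer 0's q_proj is excluded from quantization, so its fused
# qkv_proj is built without a quant config while o_proj keeps it.
cfg = DummyQuantConfig(ignore=["model.layers.0.self_attn.q_proj"])
print(select_quant_configs(cfg, "model.layers.0.self_attn"))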
python/sglang/srt/models/mllama4.py

(This diff is collapsed in the page view and is not shown here.)
python/sglang/srt/multimodal/processors/base_processor.py
@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audio"] = audios

         processor = self._processor
-        if hasattr(processor, "image_processor") and isinstance(
-            processor.image_processor, BaseImageProcessorFast
+        if (
+            hasattr(processor, "image_processor")
+            and isinstance(processor.image_processor, BaseImageProcessorFast)
+            and not self.server_args.disable_fast_image_processor
         ):
             kwargs["device"] = "cuda"
         result = processor.__call__(
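The widened guard above moves image preprocessing to the GPU only when the processor exposes a fast (torchvision-backed) image processor and the user has not disabled it. A minimal sketch of the same decision as a standalone helper, illustrative only, with attribute names mirroring the diff:

from transformers import BaseImageProcessorFast


def preprocess_device_kwargs(processor, disable_fast_image_processor: bool) -> dict:
    # Build the extra kwargs passed to processor.__call__ for preprocessing.
    kwargs = {}
    if (
        hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
        and not disable_fast_image_processor
    ):
        # Fast image processors can run their tensor ops on the GPU.
        kwargs["device"] = "cuda"
    return kwargs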