Commit ea93079b (unverified)
Authored Aug 02, 2025 by Wenchen Lo; committed by GitHub on Aug 02, 2025
model: adapt mllama4 to VisionAttention (#8512)
Co-authored-by: root <mickjagger19@icloud.com>
Parent: 4bec99ec
Showing 6 changed files with 518 additions and 52 deletions (+518 -52).
python/sglang/srt/hf_transformers_utils.py                  +25  -10
python/sglang/srt/layers/attention/vision.py                +27  -10
python/sglang/srt/managers/tokenizer_manager.py             +23   -8
python/sglang/srt/models/llama4.py                          +11   -2
python/sglang/srt/models/mllama4.py                        +428  -19
python/sglang/srt/multimodal/processors/base_processor.py    +4   -3
python/sglang/srt/hf_transformers_utils.py

@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e

     tokenizer = get_tokenizer_from_processor(processor)

     attach_additional_stop_token_ids(tokenizer)
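The change above wraps AutoProcessor.from_pretrained in a retry: if loading fails because the processor "does not have a slow version", it is reloaded with use_fast=True. A minimal standalone sketch of that pattern follows; the helper name load_processor_with_fast_fallback is illustrative, not part of sglang.

```python
# Sketch of the fallback above, assuming only that transformers' AutoProcessor
# raises ValueError with this message when a model ships only a fast processor.
from transformers import AutoProcessor


def load_processor_with_fast_fallback(name: str, **kwargs):
    try:
        return AutoProcessor.from_pretrained(name, **kwargs)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Only a fast (torch-based) processor exists; retry with use_fast=True.
            kwargs["use_fast"] = True
            return AutoProcessor.from_pretrained(name, **kwargs)
        raise
```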
python/sglang/srt/layers/attention/vision.py

@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )

+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
             print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition

         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
             ]

         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)

-            q = q.view(original_shape)
-            k = k.view(original_shape)
+                q = q.view(original_shape)
+                k = k.view(original_shape)

         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
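The new customized_position_embedding_applier hook lets a model override the default rotary path: when it is set, VisionAttention.forward calls it with (q, k, position_embeddings, x_shape) and restores the original q/k shape afterwards; otherwise the existing apply_rotary_pos_emb path runs unchanged. A hedged sketch of the callback contract implied by the diff follows; the function name and the (cos, sin) assumption about position_embeddings are illustrative.

```python
from typing import Any, Tuple

import torch


def example_position_embedding_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,  # model-specific object, e.g. a (cos, sin) tuple
    x_shape: torch.Size,       # original [bsz, seq, hidden] shape captured in forward()
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative applier: must return (q, k) in a layout that forward() can
    restore to original_shape via .view() afterwards."""
    # A real implementation would apply the model's rotary / 2D position
    # encoding here; this identity version only shows the expected signature.
    return q, k
```

A model would pass such a callable via customized_position_embedding_applier= when constructing VisionAttention.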
python/sglang/srt/managers/tokenizer_manager.py

@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -202,13 +201,29 @@ class TokenizerManager:
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
             transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
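The TokenizerManager applies the same fallback at startup: even with --disable-fast-image-processor, a model that ships only a fast image processor no longer fails to load; the manager logs a message and retries with use_fast=True. A tiny sketch of the resulting flag resolution, using a hypothetical helper rather than sglang's API:

```python
def resolve_use_fast(disable_fast_image_processor: bool, only_fast_available: bool) -> bool:
    """Mirror the fallback behaviour: honour the user's request unless only a
    fast processor exists, in which case fall back to the fast one."""
    requested_fast = not disable_fast_image_processor
    if not requested_fast and only_fast_available:
        return True  # "does not have a slow version" -> use the fast processor anyway
    return requested_fast
```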
python/sglang/srt/models/llama4.py

@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
             if self.use_qk_norm
             else None
         )
+
+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
             prefix=add_prefix("qkv_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
             prefix=add_prefix("o_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
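This change lets a quantization config exclude individual projections: if quant_config.ignore lists the prefixed q_proj or o_proj module, the corresponding linear layer is built without quantization. A self-contained sketch of that selection logic follows, using a stand-in config class and a local add_prefix helper (both illustrative; sglang has its own).

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class DummyQuantConfig:
    # Stand-in for a real quantization config carrying an `ignore` list of
    # module names excluded from quantization.
    ignore: List[str] = field(default_factory=list)


def add_prefix(name: str, prefix: str) -> str:
    # Local stand-in for sglang's helper: join "<prefix>.<name>".
    return f"{prefix}.{name}" if prefix else name


def pick_quant_configs(
    quant_config: Optional[DummyQuantConfig], prefix: str
) -> Tuple[Optional[DummyQuantConfig], Optional[DummyQuantConfig]]:
    qkv_quant_config = quant_config
    o_quant_config = quant_config
    if quant_config and getattr(quant_config, "ignore", None):
        # As in the diff, the ignore list is keyed on "q_proj", even though the
        # fused module built from it is "qkv_proj".
        if add_prefix("q_proj", prefix) in quant_config.ignore:
            qkv_quant_config = None
        if add_prefix("o_proj", prefix) in quant_config.ignore:
            o_quant_config = None
    return qkv_quant_config, o_quant_config


# Example: a checkpoint that keeps layer 0's attention projections unquantized.
cfg = DummyQuantConfig(
    ignore=["model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.o_proj"]
)
print(pick_quant_configs(cfg, "model.layers.0.self_attn"))  # -> (None, None)
```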
python/sglang/srt/models/mllama4.py

(Diff collapsed in the original view: +428 -19; contents not shown here.)
python/sglang/srt/multimodal/processors/base_processor.py

@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

-from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
                 kwargs["audio"] = audios

             processor = self._processor
-            if hasattr(processor, "image_processor") and isinstance(
-                processor.image_processor, BaseImageProcessorFast
+            if (
+                hasattr(processor, "image_processor")
+                and isinstance(processor.image_processor, BaseImageProcessorFast)
+                and not self.server_args.disable_fast_image_processor
             ):
                 kwargs["device"] = "cuda"
             result = processor.__call__(
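The base processor now sends image preprocessing to the GPU only when the loaded processor is a fast (torch-based) image processor and fast image processing has not been disabled. A hedged sketch of that gate; image_device_kwargs is a hypothetical helper, with `processor` and the flag standing in for self._processor and server_args.disable_fast_image_processor.

```python
from transformers import BaseImageProcessorFast


def image_device_kwargs(processor, disable_fast_image_processor: bool) -> dict:
    kwargs = {}
    if (
        hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
        and not disable_fast_image_processor
    ):
        # Fast image processors operate on torch tensors and can preprocess on CUDA.
        kwargs["device"] = "cuda"
    return kwargs
```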