change / sglang · Commits · ea93079b

Unverified commit ea93079b, authored Aug 02, 2025 by Wenchen Lo, committed via GitHub on Aug 02, 2025
model: adapt mllama4 to VisionAttention (#8512)
Co-authored-by: root <mickjagger19@icloud.com>
Parent: 4bec99ec
Showing 6 changed files with 518 additions and 52 deletions:

  python/sglang/srt/hf_transformers_utils.py                  +25   -10
  python/sglang/srt/layers/attention/vision.py                +27   -10
  python/sglang/srt/managers/tokenizer_manager.py             +23   -8
  python/sglang/srt/models/llama4.py                          +11   -2
  python/sglang/srt/models/mllama4.py                         +428  -19
  python/sglang/srt/multimodal/processors/base_processor.py   +4    -3
python/sglang/srt/hf_transformers_utils.py

@@ -14,7 +14,6 @@
"""Utilities for Huggingface Transformers."""

import contextlib
import logging
import os
import warnings
from pathlib import Path

@@ -45,7 +44,7 @@ from sglang.srt.configs import (
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ChatGLMConfig.model_type: ChatGLMConfig,

@@ -317,15 +316,31 @@ def get_processor(
    if config.model_type not in {"llava", "clip"}:
        kwargs["use_fast"] = use_fast
-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e

    tokenizer = get_tokenizer_from_processor(processor)

    attach_additional_stop_token_ids(tokenizer)
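Aside: a standalone sketch (not part of the commit) of the fallback pattern this hunk introduces. Some checkpoints ship only a fast image processor, so requesting the slow one raises a ValueError whose message contains "does not have a slow version"; the load is then retried with use_fast=True. The helper name below is hypothetical.

from transformers import AutoProcessor

def load_processor_with_fast_fallback(name: str, **kwargs):
    # Respect the caller's use_fast preference, but retry with the fast
    # implementation when the slow one simply does not exist.
    try:
        return AutoProcessor.from_pretrained(name, **kwargs)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            kwargs["use_fast"] = True
            return AutoProcessor.from_pretrained(name, **kwargs)
        raise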
python/sglang/srt/layers/attention/vision.py

@@ -4,7 +4,7 @@ import dataclasses
import functools
import math
from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.nn as nn

@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()
        output = flash_attn_varlen_func(
            q,
            k,

@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
        qkv_bias: bool = True,
        qk_normalization: bool = False,
        layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
        **kwargs,
    ):
        super().__init__()

@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
            self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
        )

        # priority: server_args > passed qkv_backend > sdpa
        if global_server_args_dict["mm_attention_backend"] is None:
            if qkv_backend is None:
                qkv_backend = "sdpa"

@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
            print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
        self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
            head_dim=self.head_size,
            num_heads=self.num_attention_heads_per_partition,

@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
        head = self.num_attention_heads_per_partition
        kv_head = self.num_attention_kv_heads_per_partition
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
            qkv, _ = self.qkv_proj(x)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # [b, s, embed_dim] --> [b * s, head, head_size]

@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
            ]

        if position_embeddings is not None:
-            cos, sin = position_embeddings
            original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
-            q = q.view(original_shape)
-            k = k.view(original_shape)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
+                q = q.view(original_shape)
+                k = k.view(original_shape)

        if q.dim() == 4:
            # [b, s, head, head_size] --> [b * s, head, head_size]
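Aside: for illustration only (not from the commit), a callable that satisfies the new customized_position_embedding_applier hook. VisionAttention invokes it as applier(q, k, position_embeddings, x_shape) and expects (q, k) back; it restores the original q/k shapes itself afterwards. The (cos, sin) packing and the rotate_half helper below are assumptions for the sketch; mllama4's real applier is apply_position_embedding in the file further down.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dim in half and rotate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def sincos_position_applier(q, k, position_embeddings, x_shape):
    # Hypothetical applier: assumes position_embeddings is a (cos, sin) pair
    # broadcastable against [batch, seq, heads, head_dim].
    cos, sin = position_embeddings
    bsz, seq = x_shape[0], x_shape[1]
    q = q.view(bsz, seq, -1, q.shape[-1])
    k = k.view(bsz, seq, -1, k.shape[-1])
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k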
python/sglang/srt/managers/tokenizer_manager.py

@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    BlockReqType,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,

@@ -202,13 +201,29 @@ class TokenizerManager:
        if self.model_config.is_multimodal:
            import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e

            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # We want to parallelize the image pre-processing so we create an executor for it
python/sglang/srt/models/llama4.py

@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
            if self.use_qk_norm
            else None
        )

+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
+
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
            prefix=add_prefix("qkv_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,

@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
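Aside: as a standalone illustration (not the commit's code), the per-layer opt-out above boils down to this: a quantization config may carry an ignore list of fully-qualified module names, and any projection whose prefixed name appears there is built without quantization. The helper name is hypothetical.

def resolve_layer_quant_config(quant_config, layer_name: str, prefix: str):
    # Mirrors the qkv_proj / o_proj handling in Llama4Attention:
    # return None (no quantization) when "<prefix>.<layer_name>" is ignored.
    full_name = f"{prefix}.{layer_name}" if prefix else layer_name
    if quant_config is not None and getattr(quant_config, "ignore", None):
        if full_name in quant_config.ignore:
            return None
    return quant_config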
python/sglang/srt/models/mllama4.py

import json as json_lib
import logging
import math
import os
from collections.abc import Iterable
from typing import List, Optional, Set, Tuple

import torch
from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionConfig
from transformers.models.llama4.modeling_llama4 import (
    Llama4MultiModalProjector,
-    Llama4VisionModel,
+    vision_apply_rotary_emb,
)

from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import QuantizationConfig

@@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import (
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cpu
from sglang.srt.utils import is_cpu

_is_cpu = is_cpu()
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,

@@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)


class Llama4VisionMLP(nn.Module):

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
        self.fc1 = cls_fc1(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
        self.fc2 = cls_fc2(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        if self.output_activation:
            return self.activation_fn(hidden_states)
        return hidden_states
def pixel_shuffle(input_tensor, shuffle_ratio):
    # input_tensor: [batch_size, num_patches, channels]
    batch_size, num_patches, channels = input_tensor.shape
    patch_size = int(math.sqrt(num_patches))

    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()

    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor
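Aside: a quick shape check of pixel_shuffle (illustrative only; 576 patches of width 1408 and a ratio of 0.5 are assumed values roughly matching the Llama 4 vision tower). The spatial patch count shrinks by ratio**2 while the channel width grows by the same factor.

import torch

x = torch.randn(2, 576, 1408)            # [num_tiles, num_patches, channels]
y = pixel_shuffle(x, shuffle_ratio=0.5)  # 24x24 patches -> 12x12 "super patches"
print(y.shape)                           # torch.Size([2, 144, 5632])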
class Llama4VisionPixelShuffleMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(encoded_patches)


def apply_position_embedding(q, k, freqs_ci, shape):
    # [batch_size_times_num_tiles, num_channels]
    input_shape = shape[:2]
    # [batch_size_times_num_tiles, num_channels, num_heads, head_dim]
    hidden_shape = (*input_shape, *q.shape[-2:])
    q = q.view(hidden_shape)
    k = k.view(hidden_shape)
    q, k = vision_apply_rotary_emb(q, k, freqs_ci)
    return q, k
class Llama4VisionEncoderLayer(nn.Module):

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            use_qkv_parallel=True,
            # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
            quant_config=None,
            dropout=0.0,
            qkv_backend="sdpa",
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
            qkv_bias=True,
            customized_position_embedding_applier=apply_position_embedding,
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci)
        hidden_state = residual + hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = hidden_state
        return outputs


class Llama4VisionEncoder(nn.Module):

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO: move this to an attribute instead of keeping it around
    ) -> torch.Tensor:
        r"""
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an
                embedded representation. This is useful if you want more control over how to
                convert `input_ids` indices into associated vectors than the model's internal
                embedding lookup matrix.
        """
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci)
            hidden_states = layer_outputs

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        kernel_size = config.patch_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        params = {
            "input_size": config.num_channels * kernel_size[0] * kernel_size[1],
            "output_size": config.hidden_size,
            "bias": False,
            "quant_config": quant_config,
            "prefix": f"{prefix}.linear",
        }
        if use_data_parallel:
            cls = ReplicatedLinear
        else:
            cls = ColumnParallelLinear
            params["gather_output"] = True
        self.linear = cls(**params)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.unfold(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states, _ = self.linear(hidden_states)
        return hidden_states
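Aside: a shape check (not from the commit; the 336x336 input and 14x14 patch size are assumed values matching the usual Llama 4 vision configuration). nn.Unfold with kernel_size equal to the stride flattens each non-overlapping patch into a column, and the permute puts the patches on the sequence axis before the linear projection to hidden_size.

import torch

unfold = torch.nn.Unfold(kernel_size=(14, 14), stride=14)
pixels = torch.randn(2, 3, 336, 336)       # [num_tiles, channels, height, width]
patches = unfold(pixels)                    # [2, 3 * 14 * 14, 576] = [2, 588, 576]
patches = patches.permute(0, 2, 1)          # [2, 576, 588], ready for the linear layer
print(patches.shape)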
class Llama4VisionRotaryEmbedding(nn.Module):

    def __init__(self, config):
        super().__init__()
        idx = config.image_size // config.patch_size
        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
        )
        freqs_x = (
            (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs_y = (
            (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        freq_cis = torch.view_as_complex(
            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        )
        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2

    def forward(self, hidden_states):
        return self.freqs_ci.to(hidden_states.device)
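Aside: a hedged sketch (not from the commit) of how complex-valued frequency tables such as freqs_ci are typically applied: adjacent pairs of the head dimension are treated as complex numbers, rotated by multiplication, and flattened back. This is the mechanism behind transformers' vision_apply_rotary_emb, which the apply_position_embedding hook above delegates to; the function below is illustrative only.

import torch

def apply_complex_rope(x: torch.Tensor, freqs_ci: torch.Tensor) -> torch.Tensor:
    # x: [..., head_dim] with an even head_dim; freqs_ci: complex tensor
    # broadcastable against the paired view of x.
    x_pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = torch.view_as_real(x_pairs * freqs_ci).flatten(-2)
    return x_rotated.type_as(x)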
class Llama4VisionModel(nn.Module):

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Patch embedding
        hidden_state = self.patch_embedding(pixel_values)
        num_tiles, num_patches, hidden_dim = hidden_state.shape

        # Add cls token
        class_embedding = self.class_embedding.expand(
            hidden_state.shape[0], 1, hidden_state.shape[-1]
        )
        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
        num_patches += 1

        # Position embeddings
        hidden_state = hidden_state.reshape(
            num_tiles,
            1,
            num_patches,
            hidden_dim,
        )
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=hidden_state.dtype, device=hidden_state.device
        )
        hidden_state = hidden_state + positional_embedding
        hidden_state = self.layernorm_pre(hidden_state)
        hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
        freqs_ci = self.rotary_embedding(pixel_values)

        # Apply encoder
        hidden_state = self.model(hidden_state, freqs_ci=freqs_ci)
        hidden_state = self.layernorm_post(hidden_state)

        # Remove CLS token output
        hidden_state = hidden_state[:, :-1, :]

        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
        hidden_state = self.vision_adapter(hidden_state)

        return hidden_state


class Llama4ForConditionalGeneration(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module):
        if not self.has_vision_weights:
            logger.warning(
                "No vision weights found in checkpoint. Model will run in text-only mode. "
-                "Multimodal capabilities (image processing) will be unavailable."
+                "Multimodal capabilities (vision understanding) will be unavailable. "
+                "Please not that this warning might be inaccurate if the weights haven't been fully downloaded"
            )

        self.has_vision = (

@@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module):
        )

        if self.has_vision:
-            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                prefix=add_prefix("vision_model", prefix),
+            )
            self.multi_modal_projector = Llama4MultiModalProjector(config)
        else:
            self.vision_model = None

@@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module):
                filename="model.safetensors.index.json",
                cache_dir=None,
            )

            if index_file_path and os.path.exists(index_file_path):
                return self._check_vision_weights_in_index(index_file_path)

@@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            # If we can't access the cache, fall back to config-based detection
            pass

        # Fallback, assume text-only
        return False

    def _check_vision_weights_in_index(self, index_file: str) -> bool:

@@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module):
        vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
        weight_names = index_data.get("weight_map", {}).keys()

        return any(
            pattern in weight_name
            for weight_name in weight_names
@@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module):
        # For text-only models, return None or raise an error
        if not self.has_vision or self.vision_model is None:
            raise ValueError("Vision model not available for text-only checkpoint")

        pixel_values = (
            torch.concat([item.feature for item in items])
            .to(next(self.vision_model.parameters()).device)
            .type(next(self.vision_model.parameters()).dtype)
        )

-        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
-        image_features = image_outputs.last_hidden_state
+        image_features = self.vision_model(pixel_values)
        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)
        return projected_vision_flat

    def forward(
@@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module):
            num_experts=num_experts,
        )

+        loaded_params = set()
        for name, loaded_weight in weights:
            if self._should_skip_weight(name):
                continue

            name = self._transform_weight_name(name)

-            if "vision" not in name:
+            if "vision" in name:
+                name = name.replace(".self_attn.o_proj", ".self_attn.proj")
+            else:
                name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)

            if self._handle_scale_remapping(name, params_dict):
+                loaded_params.add(name)
                continue

            if self._handle_stacked_params(
-                name, loaded_weight, stacked_params_mapping, params_dict
+                name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
            ):
                continue

            if self._handle_expert_weights(
-                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+                name,
+                loaded_weight,
+                expert_params_mapping,
+                params_dict,
+                num_experts,
+                loaded_params,
            ):
                continue

+            loaded_params.add(name)
            self._handle_default_weight(name, loaded_weight, params_dict)

+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            logger.warning(
+                f"Some weights are not initialized from checkpoints {unloaded_params}"
+            )

    def _should_skip_weight(self, name: str) -> bool:
        """Check if we should skip loading this weight."""
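Aside (illustrative, not the commit's code): the bookkeeping added throughout load_weights follows one simple pattern, shown in isolation below with a hypothetical params_dict and weight iterator. Every branch records the parameter names it actually loads, and whatever remains unaccounted for is reported at the end.

import logging

logger = logging.getLogger(__name__)

def load_with_accounting(params_dict, checkpoint_weights):
    # params_dict: name -> torch.nn.Parameter; checkpoint_weights: iterable of (name, tensor)
    loaded_params = set()
    for name, tensor in checkpoint_weights:
        param = params_dict.get(name)
        if param is None:
            continue
        param.data.copy_(tensor)
        loaded_params.add(name)
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        logger.warning(f"Some weights are not initialized from checkpoints {unloaded_params}")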
@@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module):
        loaded_weight: torch.Tensor,
        stacked_params_mapping: list,
        params_dict: dict,
+        loaded_params: set,
    ) -> bool:
        """Handle stacked parameter loading. Returns True if handled."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
-            if weight_name in name and "vision" not in name:
+            if weight_name in name:
                transformed_name = name.replace(weight_name, param_name)
+                loaded_params.add(transformed_name)
                param = params_dict[transformed_name]
                param.weight_loader(param, loaded_weight, shard_id)
                return True

@@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module):
        expert_params_mapping: list,
        params_dict: dict,
        num_experts: int,
+        loaded_params: set,
    ) -> bool:
        """Handle expert weight loading for MoE (Mixture of Experts) layers.

@@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module):
        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
            return self._handle_other_expert_params(
-                name, loaded_weight, expert_params_mapping, params_dict
+                name, loaded_weight, expert_params_mapping, params_dict, loaded_params
            )

        if "scale" in name:
            return self._handle_expert_scale_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
            )
        else:
            return self._handle_expert_weight_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
            )

    def _handle_other_expert_params(

@@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module):
        loaded_weight: torch.Tensor,
        expert_params_mapping: list,
        params_dict: dict,
+        loaded_params: set,
    ) -> bool:
        """Handle expert parameters that are not gate_up_proj or down_proj weights.

@@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
            params_dict: Dictionary of model parameters
+            loaded_params: Set of loaded parameter names

        Returns:
            bool: True if parameter was found and handled, False otherwise

@@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module):
                param.weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
+                loaded_params.add(transformed_name)
                return True
        return False

@@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module):
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
+        loaded_params: set,
    ) -> bool:
        """Handle quantization scale parameters for expert weights.

@@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            loaded_weight: Scale tensor to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for broadcast operations
+            loaded_params: Set of loaded parameter names

        Returns:
            bool: True (always handles scale parameters)

@@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            # Load the same scale for all experts
            for expert_id in range(num_experts):
                param.data[expert_id] = loaded_weight
+            loaded_params.add(transformed_name)
        return True

@@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module):
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
+        loaded_params: set,
    ) -> bool:
        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

@@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            loaded_weight: Weight tensor(s) to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for tensor distribution
+            loaded_params: Set of loaded parameter names

        Returns:
            bool: True (always handles weight parameters)

@@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module):
            param = params_dict[param_name]
            weight_loader = param.weight_loader
+            loaded_params.add(param_name)

            # Handle the case where loaded_weight might be a single tensor for all experts
            if weight_chunk.dim() == 2:
python/sglang/srt/multimodal/processors/base_processor.py

@@ -12,7 +12,6 @@ import torch
from PIL import Image
from transformers import BaseImageProcessorFast

from sglang.srt.managers.mm_utils import TransportProxyTensor
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import load_audio, load_image, load_video, logger

@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
            kwargs["audio"] = audios

        processor = self._processor
-        if hasattr(processor, "image_processor") and isinstance(
-            processor.image_processor, BaseImageProcessorFast
+        if (
+            hasattr(processor, "image_processor")
+            and isinstance(processor.image_processor, BaseImageProcessorFast)
+            and not self.server_args.disable_fast_image_processor
        ):
            kwargs["device"] = "cuda"
        result = processor.__call__(
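Aside (a minimal sketch, not from the commit): the effect of the extended condition above is that the preprocessing call only receives device="cuda" when the processor really is a torch-backed fast image processor and the user has not disabled it. The helper name below is hypothetical.

from transformers import BaseImageProcessorFast

def extra_processor_kwargs(processor, disable_fast_image_processor: bool) -> dict:
    # Mirrors the gate added in BaseMultimodalProcessor.
    if (
        hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
        and not disable_fast_image_processor
    ):
        return {"device": "cuda"}
    return {}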