sglang commit ea93079b (unverified), authored by Wenchen Lo on Aug 02, 2025, committed via GitHub on Aug 02, 2025. Parent: 4bec99ec

model: adapt mllama4 to VisionAttention (#8512)

Co-authored-by: root <mickjagger19@icloud.com>
Showing 6 changed files with 518 additions and 52 deletions (+518 / -52):

- python/sglang/srt/hf_transformers_utils.py (+25, -10)
- python/sglang/srt/layers/attention/vision.py (+27, -10)
- python/sglang/srt/managers/tokenizer_manager.py (+23, -8)
- python/sglang/srt/models/llama4.py (+11, -2)
- python/sglang/srt/models/mllama4.py (+428, -19)
- python/sglang/srt/multimodal/processors/base_processor.py (+4, -3)
python/sglang/srt/hf_transformers_utils.py

```diff
@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e
 
     tokenizer = get_tokenizer_from_processor(processor)
 
     attach_additional_stop_token_ids(tokenizer)
```
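The same retry-on-ValueError fallback is added to `TokenizerManager` below. As a standalone illustration, here is a minimal sketch of the pattern, assuming only that `AutoProcessor.from_pretrained` raises a `ValueError` mentioning "does not have a slow version" for fast-only processors; the `load_processor` helper is hypothetical and not part of this commit:

```python
import logging

from transformers import AutoProcessor

logger = logging.getLogger(__name__)


def load_processor(name: str, use_fast: bool, **kwargs):
    # Hypothetical helper mirroring the fallback introduced above.
    try:
        return AutoProcessor.from_pretrained(name, use_fast=use_fast, **kwargs)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Some processors ship only a fast implementation; retry with use_fast=True.
            logger.info("Processor %s has no slow version; using the fast version.", name)
            return AutoProcessor.from_pretrained(name, use_fast=True, **kwargs)
        raise
```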
python/sglang/srt/layers/attention/vision.py

```diff
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
             print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
             ]
 
         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
-            q = q.view(original_shape)
-            k = k.view(original_shape)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
+                q = q.view(original_shape)
+                k = k.view(original_shape)
 
         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
```
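The new `customized_position_embedding_applier` hook lets a model replace the default `apply_rotary_pos_emb` path with its own rotary application (mllama4 below passes `apply_position_embedding`). A hedged sketch of the expected callable shape; `dummy_rope_applier` is illustrative only and applies no real rotation:

```python
from typing import Any, Tuple

import torch


def dummy_rope_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,
    x_shape: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # A real applier reshapes q/k using x_shape and applies model-specific
    # rotary embeddings; see apply_position_embedding in mllama4.py below.
    return q, k


# Wiring it up (argument layout follows the mllama4 call site below):
# attn = VisionAttention(
#     hidden_size,
#     num_attention_heads,
#     hidden_size,
#     use_qkv_parallel=True,
#     qkv_backend="sdpa",
#     customized_position_embedding_applier=dummy_rope_applier,
# )
```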
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -202,13 +201,29 @@ class TokenizerManager:
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
 
         transport_mode = _determine_tensor_transport_mode(self.server_args)
 
         # We want to parallelize the image pre-processing so we create an executor for it
```
python/sglang/srt/models/llama4.py

```diff
@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
             if self.use_qk_norm
             else None
         )
+
+        qkv_quant_config = quant_config
+        o_quant_config = quant_config
+        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+            if add_prefix("q_proj", prefix) in quant_config.ignore:
+                qkv_quant_config = None
+            if add_prefix("o_proj", prefix) in quant_config.ignore:
+                o_quant_config = None
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=qkv_quant_config,
             prefix=add_prefix("qkv_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
@@ -257,7 +266,7 @@ class Llama4Attention(nn.Module):
             input_size=self.total_num_heads * self.head_dim,
             output_size=hidden_size,
             bias=bias_o_proj,
-            quant_config=quant_config,
+            quant_config=o_quant_config,
             prefix=add_prefix("o_proj", prefix),
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
```
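The new `qkv_quant_config` / `o_quant_config` selection honors a quantization config whose `ignore` list excludes specific projections (for example FP8 checkpoints that leave some layers unquantized). A toy sketch of the decision, assuming only that `ignore` holds fully prefixed module names; `ToyQuantConfig` and `join_prefix` are illustrative stand-ins, not the real `QuantizationConfig` API:

```python
from dataclasses import dataclass, field
from typing import List, Optional


def join_prefix(name: str, prefix: str) -> str:
    # Stand-in for sglang's add_prefix: join with "." when a prefix is present.
    return f"{prefix}.{name}" if prefix else name


@dataclass
class ToyQuantConfig:
    ignore: List[str] = field(default_factory=list)


def pick_quant_configs(quant_config: Optional[ToyQuantConfig], prefix: str):
    """Return (qkv_quant_config, o_quant_config), dropping quantization for ignored projections."""
    qkv_qc, o_qc = quant_config, quant_config
    if quant_config and getattr(quant_config, "ignore", None):
        if join_prefix("q_proj", prefix) in quant_config.ignore:
            qkv_qc = None
        if join_prefix("o_proj", prefix) in quant_config.ignore:
            o_qc = None
    return qkv_qc, o_qc


cfg = ToyQuantConfig(ignore=["model.layers.0.self_attn.q_proj"])
print(pick_quant_configs(cfg, "model.layers.0.self_attn"))  # (None, cfg)
```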
python/sglang/srt/models/mllama4.py

```diff
 import json as json_lib
 import logging
+import math
 import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionConfig
 from transformers.models.llama4.modeling_llama4 import (
     Llama4MultiModalProjector,
-    Llama4VisionModel,
+    vision_apply_rotary_emb,
 )
 
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
@@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cpu
+from sglang.srt.utils import add_prefix, is_cpu
 
 _is_cpu = is_cpu()
 
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
```
@@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix

```python
logger = logging.getLogger(__name__)


class Llama4VisionMLP(nn.Module):
    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
        self.fc1 = cls_fc1(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
        self.fc2 = cls_fc2(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        if self.output_activation:
            return self.activation_fn(hidden_states)
        return hidden_states


def pixel_shuffle(input_tensor, shuffle_ratio):
    # input_tensor: [batch_size, num_patches, channels]
    batch_size, num_patches, channels = input_tensor.shape
    patch_size = int(math.sqrt(num_patches))

    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()

    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor
```
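A quick shape check of `pixel_shuffle` as defined above: with a shuffle ratio of 0.5 it trades spatial resolution for channels, merging each 2x2 patch neighborhood into one token with 4x the channels. The tile, patch, and channel counts below are purely illustrative:

```python
import torch

x = torch.randn(2, 1024, 1408)  # [num_tiles, num_patches, channels], a 32x32 grid
y = pixel_shuffle(x, 0.5)       # function defined above
print(y.shape)                  # torch.Size([2, 256, 5632]): 16x16 grid, 4x channels
```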
```python
class Llama4VisionPixelShuffleMLP(nn.Module):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(encoded_patches)


def apply_position_embedding(q, k, freqs_ci, shape):
    # [batch_size_times_num_tiles, num_channels]
    input_shape = shape[:2]
    # [batch_size_times_num_tiles, num_channels, num_heads, head_dim]
    hidden_shape = (*input_shape, *q.shape[-2:])
    q = q.view(hidden_shape)
    k = k.view(hidden_shape)
    q, k = vision_apply_rotary_emb(q, k, freqs_ci)
    return q, k


class Llama4VisionEncoderLayer(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = VisionAttention(
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size,
            use_qkv_parallel=True,
            # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8
            quant_config=None,
            dropout=0.0,
            qkv_backend="sdpa",
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
            qkv_bias=True,
            customized_position_embedding_applier=apply_position_embedding,
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
    ):
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci)
        hidden_state = residual + hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = hidden_state
        return outputs


class Llama4VisionEncoder(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO: move this to an attribute instead of keeping it around
    ) -> torch.Tensor:
        r"""
        Args:
            hidden_states (`torch.FloatTensor` of shape
                `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation. This is useful if you
                want more control over how to convert `input_ids` indices into
                associated vectors than the model's internal embedding
                lookup matrix.
        """
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci)
            hidden_states = layer_outputs

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        kernel_size = config.patch_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        params = {
            "input_size": config.num_channels * kernel_size[0] * kernel_size[1],
            "output_size": config.hidden_size,
            "bias": False,
            "quant_config": quant_config,
            "prefix": f"{prefix}.linear",
        }
        if use_data_parallel:
            cls = ReplicatedLinear
        else:
            cls = ColumnParallelLinear
            params["gather_output"] = True
        self.linear = cls(**params)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.unfold(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states, _ = self.linear(hidden_states)
        return hidden_states
```
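`Llama4UnfoldConvolution` performs patch embedding with `nn.Unfold` followed by a (tensor-parallel) linear layer instead of a strided convolution, which keeps the projection weight in a shardable linear module. A self-contained sketch of the equivalent shape flow; the image size 336, patch size 14, and hidden size 1408 are illustrative assumptions, not necessarily the real config values:

```python
import torch
import torch.nn as nn

patch, channels, hidden = 14, 3, 1408                # illustrative values only
unfold = nn.Unfold(kernel_size=(patch, patch), stride=patch)
proj = nn.Linear(channels * patch * patch, hidden, bias=False)

pixel_values = torch.randn(2, channels, 336, 336)    # [num_tiles, C, H, W]
patches = unfold(pixel_values)                       # [2, C*patch*patch, num_patches] = [2, 588, 576]
patches = patches.permute(0, 2, 1)                   # [2, 576, 588]
embeddings = proj(patches)                           # [2, 576, 1408]
print(embeddings.shape)
```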
```python
class Llama4VisionRotaryEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        idx = config.image_size // config.patch_size
        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)
        )
        freqs_x = (
            (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs_y = (
            (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
        ).repeat_interleave(2, dim=-1)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        freq_cis = torch.view_as_complex(
            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        )
        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2

    def forward(self, hidden_states):
        return self.freqs_ci.to(hidden_states.device)


class Llama4VisionModel(nn.Module):
    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Patch embedding
        hidden_state = self.patch_embedding(pixel_values)
        num_tiles, num_patches, hidden_dim = hidden_state.shape

        # Add cls token
        class_embedding = self.class_embedding.expand(
            hidden_state.shape[0], 1, hidden_state.shape[-1]
        )
        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
        num_patches += 1

        # Position embeddings
        hidden_state = hidden_state.reshape(
            num_tiles,
            1,
            num_patches,
            hidden_dim,
        )
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=hidden_state.dtype, device=hidden_state.device
        )
        hidden_state = hidden_state + positional_embedding
        hidden_state = self.layernorm_pre(hidden_state)
        hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
        freqs_ci = self.rotary_embedding(pixel_values)

        # Apply encoder
        hidden_state = self.model(hidden_state, freqs_ci=freqs_ci)
        hidden_state = self.layernorm_post(hidden_state)

        # Remove CLS token output
        hidden_state = hidden_state[:, :-1, :]

        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
        hidden_state = self.vision_adapter(hidden_state)

        return hidden_state
```
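Unlike the Hugging Face `Llama4VisionModel` used previously, this in-repo vision tower returns the pixel-shuffled adapter output as a plain tensor, which is why the `output_hidden_states` / `last_hidden_state` plumbing disappears in the `get_image_feature` hunk further down. A hedged sketch of the resulting call pattern; the helper name is illustrative:

```python
import torch


def get_image_feature_sketch(vision_model, multi_modal_projector, pixel_values: torch.Tensor):
    # New call pattern: no `.last_hidden_state` unwrapping anymore.
    image_features = vision_model(pixel_values)                     # [tiles, tokens, dim]
    vision_flat = image_features.view(-1, image_features.size(-1))  # [tiles * tokens, dim]
    return multi_modal_projector(vision_flat)
```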
```python
class Llama4ForConditionalGeneration(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
```
```diff
@@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module):
         if not self.has_vision_weights:
             logger.warning(
                 "No vision weights found in checkpoint. Model will run in text-only mode. "
-                "Multimodal capabilities (image processing) will be unavailable."
+                "Multimodal capabilities (vision understanding) will be unavailable. "
+                "Please note that this warning might be inaccurate if the weights haven't been fully downloaded"
             )
 
         self.has_vision = (
@@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         if self.has_vision:
-            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                prefix=add_prefix("vision_model", prefix),
+            )
             self.multi_modal_projector = Llama4MultiModalProjector(config)
         else:
             self.vision_model = None
@@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module):
                 filename="model.safetensors.index.json",
                 cache_dir=None,
             )
-
             if index_file_path and os.path.exists(index_file_path):
                 return self._check_vision_weights_in_index(index_file_path)
@@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             # If we can't access the cache, fall back to config-based detection
             pass
 
         # Fallback, assume text-only
         return False
 
     def _check_vision_weights_in_index(self, index_file: str) -> bool:
@@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
-
         weight_names = index_data.get("weight_map", {}).keys()
 
         return any(
             pattern in weight_name
             for weight_name in weight_names
@@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         # For text-only models, return None or raise an error
         if not self.has_vision or self.vision_model is None:
             raise ValueError("Vision model not available for text-only checkpoint")
 
         pixel_values = (
             torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )
 
-        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
-        image_features = image_outputs.last_hidden_state
+        image_features = self.vision_model(pixel_values)
         vision_flat = image_features.view(-1, image_features.size(-1))
         projected_vision_flat = self.multi_modal_projector(vision_flat)
         return projected_vision_flat
 
     def forward(
@@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module):
             num_experts=num_experts,
         )
 
+        loaded_params = set()
         for name, loaded_weight in weights:
             if self._should_skip_weight(name):
                 continue
 
             name = self._transform_weight_name(name)
 
-            if "vision" not in name:
+            if "vision" in name:
+                name = name.replace(".self_attn.o_proj", ".self_attn.proj")
+            else:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )
 
             if self._handle_scale_remapping(name, params_dict):
+                loaded_params.add(name)
                 continue
 
             if self._handle_stacked_params(
-                name, loaded_weight, stacked_params_mapping, params_dict
+                name, loaded_weight, stacked_params_mapping, params_dict, loaded_params
             ):
                 continue
 
             if self._handle_expert_weights(
-                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+                name,
+                loaded_weight,
+                expert_params_mapping,
+                params_dict,
+                num_experts,
+                loaded_params,
             ):
                 continue
 
+            loaded_params.add(name)
             self._handle_default_weight(name, loaded_weight, params_dict)
 
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            logger.warning(
+                f"Some weights are not initialized from checkpoints {unloaded_params}"
+            )
+
     def _should_skip_weight(self, name: str) -> bool:
         """Check if we should skip loading this weight."""
```
```diff
@@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         stacked_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
     ) -> bool:
         """Handle stacked parameter loading. Returns True if handled."""
         for param_name, weight_name, shard_id in stacked_params_mapping:
-            if weight_name in name and "vision" not in name:
+            if weight_name in name:
                 transformed_name = name.replace(weight_name, param_name)
+                loaded_params.add(transformed_name)
                 param = params_dict[transformed_name]
                 param.weight_loader(param, loaded_weight, shard_id)
                 return True
@@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         expert_params_mapping: list,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle expert weight loading for MoE (Mixture of Experts) layers.
@@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module):
         if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
             return self._handle_other_expert_params(
-                name, loaded_weight, expert_params_mapping, params_dict
+                name, loaded_weight, expert_params_mapping, params_dict, loaded_params
             )
 
         if "scale" in name:
             return self._handle_expert_scale_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )
         else:
             return self._handle_expert_weight_params(
-                name, loaded_weight, params_dict, num_experts
+                name, loaded_weight, params_dict, num_experts, loaded_params
             )
 
     def _handle_other_expert_params(
@@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         expert_params_mapping: list,
         params_dict: dict,
+        loaded_params: set,
     ) -> bool:
         """Handle expert parameters that are not gate_up_proj or down_proj weights.
@@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: The weight tensor to be loaded
             expert_params_mapping: List of tuples mapping checkpoint names to model parameters
             params_dict: Dictionary of model parameters
+            loaded_params: Set of loaded parameter names
 
         Returns:
             bool: True if parameter was found and handled, False otherwise
@@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module):
                 param.weight_loader(
                     param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
                 )
+                loaded_params.add(transformed_name)
                 return True
         return False
@@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle quantization scale parameters for expert weights.
@@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Scale tensor to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for broadcast operations
+            loaded_params: Set of loaded parameter names
 
         Returns:
             bool: True (always handles scale parameters)
@@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             # Load the same scale for all experts
             for expert_id in range(num_experts):
                 param.data[expert_id] = loaded_weight
+            loaded_params.add(transformed_name)
 
         return True
@@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         loaded_weight: torch.Tensor,
         params_dict: dict,
         num_experts: int,
+        loaded_params: set,
     ) -> bool:
         """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
@@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             loaded_weight: Weight tensor(s) to be loaded
             params_dict: Dictionary of model parameters
             num_experts: Total number of experts for tensor distribution
+            loaded_params: Set of loaded parameter names
 
         Returns:
             bool: True (always handles weight parameters)
@@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             param = params_dict[param_name]
             weight_loader = param.weight_loader
+            loaded_params.add(param_name)
 
             # Handle the case where loaded_weight might be a single tensor for all experts
             if weight_chunk.dim() == 2:
```
python/sglang/srt/multimodal/processors/base_processor.py

```diff
@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
-from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audio"] = audios
 
         processor = self._processor
-        if hasattr(processor, "image_processor") and isinstance(
-            processor.image_processor, BaseImageProcessorFast
+        if (
+            hasattr(processor, "image_processor")
+            and isinstance(processor.image_processor, BaseImageProcessorFast)
+            and not self.server_args.disable_fast_image_processor
         ):
             kwargs["device"] = "cuda"
         result = processor.__call__(
```