Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32d2b406
Unverified
Commit
32d2b406
authored
Aug 23, 2025
by
Isotr0py
Committed by
GitHub
Aug 22, 2025
Browse files
[Model] Add Ovis2.5 PP support (#23405)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
22cf679a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
185 additions
and
105 deletions
+185
-105
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+1
-0
tests/models/multimodal/generation/test_common.py
tests/models/multimodal/generation/test_common.py
+1
-5
tests/models/registry.py
tests/models/registry.py
+1
-3
vllm/model_executor/models/ovis2_5.py
vllm/model_executor/models/ovis2_5.py
+20
-16
vllm/model_executor/models/siglip2navit.py
vllm/model_executor/models/siglip2navit.py
+162
-81
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
32d2b406
...
...
@@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
"openbmb/MiniCPM-Llama3-V-2_5"
:
PPTestSettings
.
fast
(),
"allenai/Molmo-7B-D-0924"
:
PPTestSettings
.
fast
(),
"AIDC-AI/Ovis2-1B"
:
PPTestSettings
.
fast
(),
"AIDC-AI/Ovis2.5-2B"
:
PPTestSettings
.
fast
(),
"microsoft/Phi-3.5-vision-instruct"
:
PPTestSettings
.
fast
(),
"mistralai/Pixtral-12B-2409"
:
PPTestSettings
.
fast
(
load_format
=
"dummy"
),
"Qwen/Qwen-VL-Chat"
:
PPTestSettings
.
fast
(),
...
...
tests/models/multimodal/generation/test_common.py
View file @
32d2b406
...
...
@@ -11,7 +11,6 @@ from pathlib import PosixPath
import
pytest
from
transformers
import
(
AutoModel
,
AutoModelForImageTextToText
,
AutoModelForTextToWaveform
,
AutoModelForVision2Seq
)
from
transformers.utils
import
is_flash_attn_2_available
from
vllm.platforms
import
current_platform
from
vllm.utils
import
identity
...
...
@@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
dtype
=
"half"
,
num_logprobs
=
10
,
patch_hf_runner
=
model_utils
.
ovis2_5_patch_hf_runner
,
marks
=
[
pytest
.
mark
.
skipif
(
not
is_flash_attn_2_available
(),
reason
=
"HF model needs `flash_attn` installed"
)],
hf_model_kwargs
=
{
"revision"
:
"refs/pr/5"
},
),
"phi3v"
:
VLMTestInfo
(
models
=
[
"microsoft/Phi-3.5-vision-instruct"
],
...
...
tests/models/registry.py
View file @
32d2b406
...
...
@@ -468,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras
=
{
"1.6-llama"
:
"AIDC-AI/Ovis1.6-Llama3.2-3B"
,
"1.6-gemma"
:
"AIDC-AI/Ovis1.6-Gemma2-9B"
}),
# noqa: E501
"Ovis2_5"
:
_HfExamplesInfo
(
"AIDC-AI/Ovis2.5-2B"
,
trust_remote_code
=
True
,
max_transformers_version
=
"4.53"
,
transformers_version_reason
=
"HF model is not compatible"
),
# noqa: E501
trust_remote_code
=
True
),
"PaliGemmaForConditionalGeneration"
:
_HfExamplesInfo
(
"google/paligemma-3b-mix-224"
,
# noqa: E501
extras
=
{
"v2"
:
"google/paligemma2-3b-ft-docci-448"
}),
# noqa: E501
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-3-vision-128k-instruct"
,
...
...
vllm/model_executor/models/ovis2_5.py
View file @
32d2b406
...
...
@@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.processors.ovis2_5
import
Ovis2_5Processor
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
IMAGE_TOKEN
=
"<image>"
VIDEO_TOKEN
=
"<video>"
...
...
@@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
visual_vocab_size
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vit"
,
use_data_parallel
=
use_data_parallel
,
)
# reserved tokens for INDICATOR_IDS
head_dim
=
visual_vocab_size
-
len
(
INDICATOR_IDS
)
...
...
@@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
model_type
=
config
.
model_type
if
model_type
==
"siglip2_navit"
:
return
Siglip2NavitModel
(
config
=
config
,
)
return
Siglip2NavitModel
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
use_data_parallel
=
use_data_parallel
)
raise
ValueError
(
f
"Unsupported visual tokenizer model_type:
{
model_type
}
"
)
@
property
def
dtype
(
self
):
def
dtype
(
self
)
->
torch
.
dtype
:
return
next
(
self
.
head
.
parameters
()).
dtype
@
property
def
device
(
self
):
def
device
(
self
)
->
torch
.
device
:
return
next
(
self
.
head
.
parameters
()).
device
def
tokenize
(
self
,
logits
)
:
def
tokenize
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tokens
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
logits
.
dtype
)
return
tokens
def
encode
(
self
,
pixel_values
,
grid_thws
):
features
=
self
.
vit
(
pixel_values
,
grid_thws
,
output_hidden_states
=
True
,
return_dict
=
True
)
def
encode
(
self
,
pixel_values
:
torch
.
Tensor
,
grid_thws
:
torch
.
Tensor
)
->
torch
.
Tensor
:
features
=
self
.
vit
(
pixel_values
,
grid_thws
)
# refer to qwen2.5-vl patchmerger
seq_len
,
_
=
features
.
shape
features
=
features
.
reshape
(
seq_len
//
(
self
.
config
.
hidden_stride
**
2
),
...
...
@@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
return
features
def
forward
(
self
,
pixel_values
,
grid_thws
)
->
torch
.
Tensor
:
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
grid_thws
:
torch
.
Tensor
)
->
torch
.
Tensor
:
features
=
self
.
encode
(
pixel_values
,
grid_thws
)
logits
=
self
.
head
(
features
)
tokens
=
self
.
tokenize
(
logits
)
...
...
@@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
@
MULTIMODAL_REGISTRY
.
register_processor
(
Ovis2_5MultiModalProcessor
,
info
=
Ovis2_5ProcessingInfo
,
dummy_inputs
=
Ovis2_5DummyInputsBuilder
)
class
Ovis2_5
(
nn
.
Module
,
SupportsMultiModal
):
class
Ovis2_5
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
text_model_type
=
self
.
config
.
get_text_config
().
model_type
self
.
image_pad_token_id
=
IMAGE_PAD_TOKEN_ID_MAP
[
text_model_type
]
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
self
.
make_empty_intermediate_tensors
=
(
self
.
get_language_model
().
make_empty_intermediate_tensors
)
def
_parse_and_validate_visual_input
(
self
,
is_video
,
...
...
@@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
return
loader
.
load_weights
(
weights
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
llm
\ No newline at end of file
return
self
.
llm
vllm/model_executor/models/siglip2navit.py
View file @
32d2b406
...
...
@@ -3,16 +3,24 @@
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
from
typing
import
Optional
,
Union
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
from
einops
import
rearrange
,
repeat
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
transformers
.activations
import
ACT2FN
from
transformers
import
Siglip2VisionConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_outputs
import
BaseModelOutputWithNoAttention
from
vllm.config
import
QuantizationConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.platforms
import
_Backend
from
.vision
import
get_vit_attn_backend
...
...
@@ -48,10 +56,11 @@ class Siglip2VisionEmbeddings(nn.Module):
# siglip2 naflex
if
self
.
num_patches
>
0
:
self
.
patch_embedding
=
nn
.
Linear
(
in
_features
=
config
.
num_channels
*
self
.
patch_size
*
self
.
patch_embedding
=
Replicated
Linear
(
in
put_size
=
config
.
num_channels
*
self
.
patch_size
*
self
.
patch_size
,
out_features
=
self
.
embed_dim
,
output_size
=
self
.
embed_dim
,
return_bias
=
False
,
)
if
self
.
preserve_original_pe
:
self
.
position_embedding_size
=
int
(
self
.
num_patches
**
0.5
)
...
...
@@ -89,7 +98,7 @@ class Siglip2VisionEmbeddings(nn.Module):
# Apply patch embeddings to already patchified pixel values
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
if
isinstance
(
self
.
patch_embedding
,
nn
.
Linear
):
if
isinstance
(
self
.
patch_embedding
,
Linear
Base
):
patch_embeds
=
self
.
patch_embedding
(
pixel_values
.
to
(
dtype
=
target_dtype
))
elif
isinstance
(
self
.
patch_embedding
,
nn
.
Conv2d
):
...
...
@@ -184,7 +193,13 @@ def apply_rotary_pos_emb(
class
Siglip2Attention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -199,11 +214,25 @@ class Siglip2Attention(nn.Module):
self
.
dropout
=
config
.
attention_dropout
self
.
is_causal
=
False
self
.
k_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
v_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
q_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
self
.
out_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
self
.
embed_dim
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
num_heads
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
out_proj
=
RowParallelLinear
(
input_size
=
self
.
embed_dim
,
output_size
=
self
.
embed_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
tp_size
=
(
1
if
use_data_parallel
else
get_tensor_model_parallel_world_size
())
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
self
.
use_rope
=
config
.
use_rope
# Detect attention implementation.
...
...
@@ -228,13 +257,15 @@ class Siglip2Attention(nn.Module):
seq_length
,
embed_dim
=
hidden_states
.
shape
queries
=
self
.
q_proj
(
hidden_states
)
keys
=
self
.
k_proj
(
hidden_states
)
values
=
self
.
v_proj
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
queries
,
keys
,
values
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
queries
=
queries
.
view
(
seq_length
,
self
.
num_heads
,
self
.
head_dim
)
keys
=
keys
.
view
(
seq_length
,
self
.
num_heads
,
self
.
head_dim
)
values
=
values
.
view
(
seq_length
,
self
.
num_heads
,
self
.
head_dim
)
queries
=
queries
.
view
(
seq_length
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
keys
=
keys
.
view
(
seq_length
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
values
=
values
.
view
(
seq_length
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
use_rope
:
cos
,
sin
=
position_embeddings
...
...
@@ -276,41 +307,72 @@ class Siglip2Attention(nn.Module):
v_i
,
dropout_p
=
0.0
)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i
=
output_i
.
transpose
(
1
,
2
).
reshape
(
-
1
,
self
.
embed_dim
)
output_i
=
output_i
.
transpose
(
1
,
2
).
reshape
(
end_idx
-
start_idx
,
-
1
)
outputs
.
append
(
output_i
)
attn_output
=
torch
.
cat
(
outputs
,
dim
=
0
)
attn_output
=
self
.
out_proj
(
attn_output
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
class
Siglip2MLP
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
hidden_states
,
_
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
Siglip2EncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
self_attn
=
Siglip2Attention
(
config
)
self
.
self_attn
=
Siglip2Attention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
use_data_parallel
=
use_data_parallel
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
Siglip2MLP
(
config
)
self
.
mlp
=
Siglip2MLP
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
use_data_parallel
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
)
->
tuple
[
torch
.
FloatTensor
]:
...
...
@@ -347,14 +409,22 @@ class Siglip2Encoder(nn.Module):
config: PretrainedConfig
"""
def
__init__
(
self
,
config
:
PretrainedConfig
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
Siglip2EncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)
Siglip2EncoderLayer
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
use_data_parallel
=
use_data_parallel
)
for
idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
gradient_checkpointing
=
False
self
.
rotary_pos_emb
=
VisionRotaryEmbedding
(
config
.
hidden_size
//
config
.
num_attention_heads
//
2
)
...
...
@@ -445,13 +515,11 @@ class Siglip2Encoder(nn.Module):
return
window_index
,
cu_window_seqlens
# Ignore copy
def
forward
(
self
,
inputs_embeds
,
inputs_embeds
:
torch
.
Tensor
,
grid_thws
:
torch
.
Tensor
,
output_hidden_states
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
tuple
[
torch
.
Tensor
,
...]]]:
)
->
torch
.
Tensor
:
r
"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
...
...
@@ -506,7 +574,6 @@ class Siglip2Encoder(nn.Module):
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
value
=
0
)
reverse_indices
=
torch
.
argsort
(
window_index
)
encoder_states
=
()
if
output_hidden_states
else
None
hidden_states
=
inputs_embeds
for
index
,
block
in
enumerate
(
self
.
layers
):
...
...
@@ -517,45 +584,40 @@ class Siglip2Encoder(nn.Module):
cu_seqlens_tmp
=
cu_window_seqlens
hidden_states
=
block
(
hidden_states
,
cu_seqlens_tmp
,
position_embeddings
)
if
output_hidden_states
:
hidden_states_
=
hidden_states
.
reshape
(
seq_len
//
self
.
spatial_merge_unit
,
self
.
spatial_merge_unit
,
-
1
)
encoder_states
+=
(
hidden_states_
[
reverse_indices
,
:].
reshape
(
seq_len
,
-
1
),
)
# tokens = self.post_trunk_norm(tokens)
hidden_states
=
hidden_states
.
reshape
(
seq_len
//
self
.
spatial_merge_unit
,
self
.
spatial_merge_unit
,
-
1
)
hidden_states
=
hidden_states
[
reverse_indices
,
:].
reshape
(
seq_len
,
-
1
)
return
hidden_states
,
encoder_states
return
hidden_states
class
Siglip2VisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
Siglip2VisionEmbeddings
(
config
)
self
.
encoder
=
Siglip2Encoder
(
config
)
self
.
encoder
=
Siglip2Encoder
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.encoder"
,
use_data_parallel
=
use_data_parallel
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
_use_flash_attention_2
=
\
(
config
.
_attn_implementation
==
"flash_attention_2"
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
grid_thws
:
torch
.
LongTensor
,
output_hidden_states
:
Optional
[
bool
]
=
True
,
return_dict
:
Optional
[
bool
]
=
True
,
)
->
Union
[
tuple
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]],
BaseModelOutputWithNoAttention
,
]:
)
->
torch
.
Tensor
:
r
"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
...
...
@@ -563,45 +625,64 @@ class Siglip2VisionTransformer(nn.Module):
"""
hidden_states
=
self
.
embeddings
(
pixel_values
,
grid_thws
)
last_hidden_state
,
hidden_states
=
self
.
encoder
(
hidden_states
,
grid_thws
,
output_hidden_states
)
last_hidden_state
=
self
.
encoder
(
hidden_states
,
grid_thws
)
last_hidden_state
=
self
.
post_layernorm
(
last_hidden_state
)
if
not
return_dict
:
output
=
(
last_hidden_state
,
)
output
+=
(
hidden_states
,
)
if
output_hidden_states
else
()
return
output
return
last_hidden_state
class
Siglip2NavitModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
):
def
__init__
(
self
,
config
:
Siglip2VisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
self
.
vision_model
=
Siglip2VisionTransformer
(
config
)
self
.
vision_model
=
Siglip2VisionTransformer
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
use_data_parallel
=
use_data_parallel
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
grid_thws
:
torch
.
LongTensor
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
tuple
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]],
BaseModelOutputWithNoAttention
,
]:
if
output_hidden_states
is
None
:
output_hidden_states
=
self
.
config
.
output_hidden_states
if
return_dict
is
None
:
return_dict
=
self
.
config
.
use_return_dict
)
->
torch
.
Tensor
:
return
self
.
vision_model
(
pixel_values
=
pixel_values
,
grid_thws
=
grid_thws
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment