Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1359 additions
and
290 deletions
+1359
-290
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+0
-1
vllm/model_executor/models/audioflamingo3.py
vllm/model_executor/models/audioflamingo3.py
+639
-0
vllm/model_executor/models/bagel.py
vllm/model_executor/models/bagel.py
+584
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+0
-1
vllm/model_executor/models/bailing_moe.py
vllm/model_executor/models/bailing_moe.py
+2
-2
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+2
-5
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+0
-1
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+5
-2
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+0
-1
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+9
-7
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+0
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+18
-23
vllm/model_executor/models/dots1.py
vllm/model_executor/models/dots1.py
+0
-1
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+58
-109
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+0
-1
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+40
-127
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+0
-1
vllm/model_executor/models/exaone4.py
vllm/model_executor/models/exaone4.py
+0
-1
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+0
-1
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+2
-5
No files found.
vllm/model_executor/models/arctic.py
View file @
a3f8d5dd
...
...
@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
...
...
vllm/model_executor/models/audioflamingo3.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Any
,
Literal
,
TypeAlias
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers.models.audioflamingo3
import
(
AudioFlamingo3Config
,
AudioFlamingo3Processor
,
)
from
transformers.models.qwen2_audio
import
Qwen2AudioEncoder
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
,
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
)
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
,
)
MAX_AUDIO_LEN
=
10
*
60
# === Audio Inputs === #
class
AudioFlamingo3FeatureInputs
(
TensorSchema
):
"""
Dimensions:
- num_chunks: Number of audio chunks (flattened)
- nmb: Number of mel bins
- num_audios: Number of original audio files
"""
type
:
Literal
[
"audio_features"
]
input_features
:
Annotated
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
TensorShape
(
"num_chunks"
,
"nmb"
,
3000
),
]
feature_attention_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"num_chunks"
,
3000
),
]
chunk_counts
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"num_audios"
),
]
class
AudioFlamingo3EmbeddingInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size
- naf: Number of audio features
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type
:
Literal
[
"audio_embeds"
]
=
"audio_embeds"
audio_embeds
:
Annotated
[
list
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"naf"
,
"hs"
),
]
AudioFlamingo3Inputs
:
TypeAlias
=
(
AudioFlamingo3FeatureInputs
|
AudioFlamingo3EmbeddingInputs
)
class
AudioFlamingo3Encoder
(
Qwen2AudioEncoder
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
):
super
().
__init__
(
config
)
self
.
avg_pooler
=
nn
.
AvgPool1d
(
kernel_size
=
2
,
stride
=
2
)
# self.layer_norm is already initialized in super().__init__
def
forward
(
self
,
input_features
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
attention_mask
:
torch
.
Tensor
=
None
,
):
# input_features: (batch, num_mel_bins, seq_len)
if
isinstance
(
input_features
,
list
):
input_features
=
torch
.
stack
(
input_features
)
hidden_states
=
nn
.
functional
.
gelu
(
self
.
conv1
(
input_features
))
hidden_states
=
nn
.
functional
.
gelu
(
self
.
conv2
(
hidden_states
))
hidden_states
=
hidden_states
.
transpose
(
-
1
,
-
2
)
hidden_states
=
(
hidden_states
+
self
.
embed_positions
.
weight
[:
hidden_states
.
size
(
-
2
),
:]
).
to
(
hidden_states
.
dtype
)
for
layer
in
self
.
layers
:
layer_outputs
=
layer
(
hidden_states
,
attention_mask
)
hidden_states
=
layer_outputs
[
0
]
# AvgPool (time/2) + LayerNorm
# hidden_states: (batch, seq_len, hidden_size)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
)
# (batch, hidden_size, seq_len)
hidden_states
=
self
.
avg_pooler
(
hidden_states
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
)
# (batch, seq_len/2, hidden_size)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
return
hidden_states
def
_get_feat_extract_output_lengths
(
self
,
input_lengths
:
torch
.
Tensor
):
"""
Computes the output length of the convolutional layers and the output length
of the audio encoder
"""
input_lengths
=
(
input_lengths
-
1
)
//
2
+
1
output_lengths
=
(
input_lengths
-
2
)
//
2
+
1
return
input_lengths
,
output_lengths
class
AudioFlamingo3MultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
):
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
config
.
audio_config
.
hidden_size
,
config
.
text_config
.
hidden_size
,
bias
=
config
.
projector_bias
,
)
self
.
act
=
get_act_fn
(
config
.
projector_hidden_act
)
self
.
linear_2
=
nn
.
Linear
(
config
.
text_config
.
hidden_size
,
config
.
text_config
.
hidden_size
,
bias
=
config
.
projector_bias
,
)
def
forward
(
self
,
audio_features
):
hidden_states
=
self
.
linear_1
(
audio_features
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
return
hidden_states
class
AudioFlamingo3ProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
AudioFlamingo3Config
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
return
self
.
ctx
.
get_hf_processor
(
AudioFlamingo3Processor
,
**
kwargs
)
def
get_feature_extractor
(
self
,
**
kwargs
:
object
):
hf_processor
=
self
.
get_hf_processor
(
**
kwargs
)
feature_extractor
=
hf_processor
.
feature_extractor
return
feature_extractor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"audio"
:
None
}
class
AudioFlamingo3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
AudioFlamingo3ProcessingInfo
]
):
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
hf_processor
=
self
.
info
.
get_hf_processor
()
audio_token
=
hf_processor
.
audio_token
return
audio_token
*
num_audios
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
]
|
None
=
None
,
)
->
MultiModalDataDict
:
feature_extractor
=
self
.
info
.
get_feature_extractor
()
sampling_rate
=
feature_extractor
.
sampling_rate
audio_len
=
MAX_AUDIO_LEN
*
sampling_rate
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
audio_overrides
=
mm_options
.
get
(
"audio"
)
if
mm_options
else
None
return
{
"audio"
:
self
.
_get_dummy_audios
(
length
=
audio_len
,
num_audios
=
num_audios
,
overrides
=
audio_overrides
,
)
}
def
_audioflamingo3_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
chunk_counts
=
hf_inputs
.
get
(
"chunk_counts"
)
if
chunk_counts
is
not
None
:
return
dict
(
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
chunk_counts
,
dim
=
0
),
feature_attention_mask
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
chunk_counts
,
dim
=
0
),
chunk_counts
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
return
dict
(
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
feature_attention_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
),
chunk_counts
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
class
AudioFlamingo3MultiModalDataParser
(
MultiModalDataParser
):
def
_parse_audio_data
(
self
,
data
:
dict
[
str
,
torch
.
Tensor
]
|
ModalityData
[
Any
],
)
->
ModalityDataItems
[
Any
,
Any
]
|
None
:
if
isinstance
(
data
,
dict
):
return
DictEmbeddingItems
(
data
,
modality
=
"audio"
,
required_fields
=
{
"audio_embeds"
},
fields_factory
=
_audioflamingo3_field_config
,
)
return
super
().
_parse_audio_data
(
data
)
class
AudioFlamingo3MultiModalProcessor
(
BaseMultiModalProcessor
[
AudioFlamingo3ProcessingInfo
]
):
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
feature_extractor
=
self
.
info
.
get_feature_extractor
()
return
AudioFlamingo3MultiModalDataParser
(
target_sr
=
feature_extractor
.
sampling_rate
)
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
dict
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
Any
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
audios
=
mm_data
.
pop
(
"audios"
,
[])
if
audios
:
mm_data
[
"audio"
]
=
audios
if
not
mm_data
.
get
(
"audio"
,
[]):
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
feature_extractor
=
self
.
info
.
get_feature_extractor
(
**
mm_kwargs
)
mm_kwargs
=
dict
(
**
mm_kwargs
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
)
# Calculate chunk counts
audio_list
=
mm_data
.
get
(
"audio"
)
if
not
isinstance
(
audio_list
,
list
):
audio_list
=
[
audio_list
]
chunk_counts
=
[]
sampling_rate
=
feature_extractor
.
sampling_rate
chunk_length
=
feature_extractor
.
chunk_length
window_size
=
int
(
sampling_rate
*
chunk_length
)
# MAX_AUDIO_LEN is 10 * 60 in HF processor.
max_windows
=
int
(
MAX_AUDIO_LEN
//
chunk_length
)
for
audio
in
audio_list
:
# audio is numpy array or list
n_samples
=
len
(
audio
)
if
isinstance
(
audio
,
list
)
else
audio
.
shape
[
0
]
n_win
=
max
(
1
,
(
n_samples
+
window_size
-
1
)
//
window_size
)
if
n_win
>
max_windows
:
n_win
=
max_windows
chunk_counts
.
append
(
n_win
)
outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
if
"input_features_mask"
in
outputs
:
outputs
[
"feature_attention_mask"
]
=
outputs
.
pop
(
"input_features_mask"
)
outputs
[
"chunk_counts"
]
=
torch
.
tensor
(
chunk_counts
,
dtype
=
torch
.
long
)
return
outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
_audioflamingo3_field_config
(
hf_inputs
)
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
audio_token
=
getattr
(
processor
,
"audio_token"
,
"<sound>"
)
audio_token_id
=
vocab
.
get
(
audio_token
)
if
audio_token_id
is
None
:
# Fallback if not found, though it should be there
audio_token_id
=
processor
.
audio_token_id
out_mm_data
=
out_mm_kwargs
.
get_data
()
feature_attention_mask
=
out_mm_data
.
get
(
"feature_attention_mask"
)
chunk_counts
=
out_mm_data
.
get
(
"chunk_counts"
)
def
get_replacement_audioflamingo3
(
item_idx
:
int
):
if
feature_attention_mask
is
not
None
:
if
chunk_counts
is
not
None
:
counts
=
(
chunk_counts
.
tolist
()
if
isinstance
(
chunk_counts
,
torch
.
Tensor
)
else
chunk_counts
)
start_idx
=
sum
(
counts
[:
item_idx
])
count
=
counts
[
item_idx
]
end_idx
=
start_idx
+
count
if
isinstance
(
feature_attention_mask
,
list
):
mask_list
=
feature_attention_mask
[
start_idx
:
end_idx
]
if
len
(
mask_list
)
>
0
and
isinstance
(
mask_list
[
0
],
torch
.
Tensor
):
mask
=
torch
.
stack
(
mask_list
)
else
:
mask
=
torch
.
tensor
(
mask_list
)
else
:
mask
=
feature_attention_mask
[
start_idx
:
end_idx
]
else
:
# feature_attention_mask is list[Tensor] or Tensor
if
isinstance
(
feature_attention_mask
,
list
):
mask
=
feature_attention_mask
[
item_idx
]
else
:
mask
=
feature_attention_mask
[
item_idx
].
unsqueeze
(
0
)
# mask shape: (num_chunks, 3000)
input_lengths
=
mask
.
sum
(
-
1
)
conv_lengths
=
(
input_lengths
-
1
)
//
2
+
1
audio_output_lengths
=
(
conv_lengths
-
2
)
//
2
+
1
num_features
=
audio_output_lengths
.
sum
().
item
()
else
:
audio_embeds
=
out_mm_data
[
"audio_embeds"
][
item_idx
]
num_features
=
audio_embeds
.
shape
[
0
]
if
num_features
==
0
:
raise
ValueError
(
"Audio is too short"
)
audio_tokens
=
[
audio_token_id
]
*
int
(
num_features
)
return
PromptUpdateDetails
.
select_token_id
(
audio_tokens
,
embed_token_id
=
audio_token_id
,
)
return
[
PromptReplacement
(
modality
=
"audio"
,
target
=
audio_token
,
replacement
=
get_replacement_audioflamingo3
,
)
]
@
MULTIMODAL_REGISTRY
.
register_processor
(
AudioFlamingo3MultiModalProcessor
,
info
=
AudioFlamingo3ProcessingInfo
,
dummy_inputs
=
AudioFlamingo3DummyInputsBuilder
,
)
class
AudioFlamingo3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
"""
AudioFlamingo3 model for conditional generation.
This model integrates a Whisper-based audio encoder with a Qwen2 language model.
It supports multi-chunk audio processing.
"""
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model."
,
connector
=
"multi_modal_projector."
,
tower_model
=
"audio_tower."
,
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
audio_tower
=
AudioFlamingo3Encoder
(
config
.
audio_config
,
)
self
.
multi_modal_projector
=
AudioFlamingo3MultiModalProjector
(
config
)
self
.
quant_config
=
quant_config
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_parse_and_validate_audio_input
(
self
,
**
kwargs
:
object
)
->
AudioFlamingo3Inputs
|
None
:
input_features
=
kwargs
.
pop
(
"input_features"
,
None
)
audio_embeds
=
kwargs
.
pop
(
"audio_embeds"
,
None
)
feature_attention_mask
=
kwargs
.
pop
(
"feature_attention_mask"
,
None
)
chunk_counts
=
kwargs
.
pop
(
"chunk_counts"
,
None
)
if
input_features
is
None
and
audio_embeds
is
None
:
return
None
if
audio_embeds
is
not
None
:
return
AudioFlamingo3EmbeddingInputs
(
type
=
"audio_embeds"
,
audio_embeds
=
audio_embeds
)
if
input_features
is
not
None
:
return
AudioFlamingo3FeatureInputs
(
type
=
"audio_features"
,
input_features
=
input_features
,
feature_attention_mask
=
feature_attention_mask
,
chunk_counts
=
chunk_counts
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_audio_input
(
self
,
audio_input
:
AudioFlamingo3Inputs
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
...]:
if
audio_input
[
"type"
]
==
"audio_embeds"
:
audio_embeds
=
audio_input
[
"audio_embeds"
]
return
tuple
(
audio_embeds
)
input_features
=
audio_input
[
"input_features"
]
feature_attention_mask
=
audio_input
[
"feature_attention_mask"
]
chunk_counts
=
audio_input
.
get
(
"chunk_counts"
)
if
isinstance
(
input_features
,
list
):
input_features
=
torch
.
cat
(
input_features
,
dim
=
0
)
feature_attention_mask
=
torch
.
cat
(
feature_attention_mask
,
dim
=
0
)
if
chunk_counts
is
None
:
chunk_counts
=
[
1
]
*
input_features
.
shape
[
0
]
elif
isinstance
(
chunk_counts
,
torch
.
Tensor
):
chunk_counts
=
chunk_counts
.
tolist
()
elif
(
isinstance
(
chunk_counts
,
list
)
and
chunk_counts
and
isinstance
(
chunk_counts
[
0
],
torch
.
Tensor
)
):
chunk_counts
=
[
c
.
item
()
for
c
in
chunk_counts
]
# Calculate output lengths
input_lengths
=
feature_attention_mask
.
sum
(
-
1
)
# Conv downsampling
conv_lengths
=
(
input_lengths
-
1
)
//
2
+
1
# AvgPool downsampling
audio_output_lengths
=
(
conv_lengths
-
2
)
//
2
+
1
batch_size
,
_
,
max_mel_seq_len
=
input_features
.
shape
# Calculate max_seq_len after convs (before pooling) for attention mask
max_seq_len
=
(
max_mel_seq_len
-
1
)
//
2
+
1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range
=
(
torch
.
arange
(
0
,
max_seq_len
,
dtype
=
conv_lengths
.
dtype
,
device
=
conv_lengths
.
device
,
)
.
unsqueeze
(
0
)
.
expand
(
batch_size
,
max_seq_len
)
)
lengths_expand
=
conv_lengths
.
unsqueeze
(
-
1
).
expand
(
batch_size
,
max_seq_len
)
# Create mask
padding_mask
=
seq_range
>=
lengths_expand
audio_attention_mask_
=
padding_mask
.
view
(
batch_size
,
1
,
1
,
max_seq_len
).
expand
(
batch_size
,
1
,
max_seq_len
,
max_seq_len
)
audio_attention_mask
=
audio_attention_mask_
.
to
(
dtype
=
self
.
audio_tower
.
conv1
.
weight
.
dtype
,
device
=
self
.
audio_tower
.
conv1
.
weight
.
device
,
)
audio_attention_mask
[
audio_attention_mask_
]
=
float
(
"-inf"
)
# Forward pass
audio_features
=
self
.
audio_tower
(
input_features
,
attention_mask
=
audio_attention_mask
)
# Project
audio_features
=
self
.
multi_modal_projector
(
audio_features
)
# Masking after pooling
num_audios
,
max_audio_tokens
,
embed_dim
=
audio_features
.
shape
audio_output_lengths
=
audio_output_lengths
.
unsqueeze
(
1
)
audio_features_mask
=
(
torch
.
arange
(
max_audio_tokens
)
.
expand
(
num_audios
,
max_audio_tokens
)
.
to
(
audio_output_lengths
.
device
)
<
audio_output_lengths
)
masked_audio_features
=
audio_features
[
audio_features_mask
].
view
(
-
1
,
embed_dim
)
# Split to tuple of embeddings for individual audio input.
chunk_embeddings
=
torch
.
split
(
masked_audio_features
,
audio_output_lengths
.
flatten
().
tolist
()
)
grouped_embeddings
=
[]
current_idx
=
0
for
count
in
chunk_counts
:
audio_chunks
=
chunk_embeddings
[
current_idx
:
current_idx
+
count
]
grouped_embeddings
.
append
(
torch
.
cat
(
audio_chunks
,
dim
=
0
))
current_idx
+=
count
return
tuple
(
grouped_embeddings
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
return
[]
masked_audio_features
=
self
.
_process_audio_input
(
audio_input
)
return
masked_audio_features
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/bagel.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.
BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
Literal
,
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.processors.bagel
import
BagelProcessor
from
vllm.utils.tensor_schema
import
TensorSchema
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
)
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
class
BagelImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
# Shape: (bn, 3, h, w)
BagelImageInputs
:
TypeAlias
=
BagelImagePixelInputs
class
BagelVisionMLP
(
nn
.
Module
):
"""MLP connector for vision features."""
def
__init__
(
self
,
in_features
:
int
,
hidden_features
:
int
,
out_features
:
int
,
act_layer
:
str
=
"gelu_pytorch_tanh"
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
fc1
=
ColumnParallelLinear
(
in_features
,
hidden_features
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
act
=
get_act_fn
(
act_layer
)
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
out_features
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
_
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
,
_
=
self
.
fc2
(
x
)
return
x
class
PositionEmbedding
(
nn
.
Module
):
"""2D position embedding for vision tokens using sin-cos embeddings."""
def
__init__
(
self
,
max_num_patch_per_side
:
int
,
hidden_size
:
int
):
super
().
__init__
()
self
.
max_num_patch_per_side
=
max_num_patch_per_side
self
.
hidden_size
=
hidden_size
# Create learnable 2D position embeddings (frozen sin-cos)
pos_embed
=
self
.
_get_2d_sincos_pos_embed
(
hidden_size
,
max_num_patch_per_side
)
self
.
register_buffer
(
"pos_embed"
,
torch
.
from_numpy
(
pos_embed
).
float
(),
persistent
=
False
,
)
@
staticmethod
def
_get_2d_sincos_pos_embed
(
embed_dim
:
int
,
grid_size
:
int
):
"""Generate 2D sin-cos position embeddings."""
import
numpy
as
np
grid_h
=
np
.
arange
(
grid_size
,
dtype
=
np
.
float32
)
grid_w
=
np
.
arange
(
grid_size
,
dtype
=
np
.
float32
)
grid
=
np
.
meshgrid
(
grid_w
,
grid_h
)
# w goes first
grid
=
np
.
stack
(
grid
,
axis
=
0
)
grid
=
grid
.
reshape
([
2
,
1
,
grid_size
,
grid_size
])
pos_embed
=
PositionEmbedding
.
_get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
)
return
pos_embed
@
staticmethod
def
_get_2d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
grid
):
"""Generate 2D sin-cos position embeddings from grid."""
import
numpy
as
np
assert
embed_dim
%
2
==
0
# use half of dimensions to encode grid_h
emb_h
=
PositionEmbedding
.
_get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
0
]
)
emb_w
=
PositionEmbedding
.
_get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
1
]
)
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=
1
)
return
emb
@
staticmethod
def
_get_1d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
pos
):
"""Generate 1D sin-cos position embeddings."""
import
numpy
as
np
assert
embed_dim
%
2
==
0
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float64
)
omega
/=
embed_dim
/
2.0
omega
=
1.0
/
10000
**
omega
pos
=
pos
.
reshape
(
-
1
)
out
=
np
.
einsum
(
"m,d->md"
,
pos
,
omega
)
emb_sin
=
np
.
sin
(
out
)
emb_cos
=
np
.
cos
(
out
)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
return
emb
def
forward
(
self
,
position_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Args:
position_ids: Flattened position IDs, shape (N,) where each ID
corresponds to a position in the flattened grid
Returns:
Position embeddings of shape (N, hidden_size)
"""
# Ensure position_ids are on the same device as pos_embed
position_ids
=
position_ids
.
to
(
self
.
pos_embed
.
device
)
return
self
.
pos_embed
[
position_ids
]
class
BagelProcessingInfo
(
BaseProcessingInfo
):
"""Processing information for BAGEL model."""
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
BagelProcessor
:
from
vllm.transformers_utils.processor
import
cached_get_image_processor
image_processor
=
cached_get_image_processor
(
self
.
ctx
.
model_config
.
model
,
revision
=
self
.
ctx
.
model_config
.
revision
,
trust_remote_code
=
self
.
ctx
.
model_config
.
trust_remote_code
,
)
tokenizer
=
self
.
get_tokenizer
()
return
BagelProcessor
(
image_processor
=
image_processor
,
tokenizer
=
tokenizer
,
**
kwargs
,
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
hf_config
=
self
.
get_hf_config
()
# Calculate max tokens per image
# For BAGEL: (vit_max_num_patch_per_side) ** 2
max_num_patches
=
hf_config
.
vit_max_num_patch_per_side
**
2
return
{
"image"
:
max_num_patches
}
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vit_config
=
hf_config
.
vit_config
patch_size
=
vit_config
.
patch_size
# Calculate number of patches
num_patches_h
=
image_height
//
patch_size
num_patches_w
=
image_width
//
patch_size
return
num_patches_h
*
num_patches_w
class
BagelDummyInputsBuilder
(
BaseDummyInputsBuilder
[
BagelProcessingInfo
]):
"""Build dummy inputs for BAGEL model profiling."""
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
# Use a simple placeholder for each image
return
"<|image_pad|>"
*
num_images
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
]
|
None
=
None
,
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
hf_config
=
self
.
info
.
get_hf_config
()
vit_config
=
hf_config
.
vit_config
# Use the configured image size
image_size
=
vit_config
.
image_size
image_overrides
=
mm_options
.
get
(
"image"
)
if
mm_options
else
None
return
{
"image"
:
self
.
_get_dummy_images
(
width
=
image_size
,
height
=
image_size
,
num_images
=
num_images
,
overrides
=
image_overrides
,
),
}
class
BagelMultiModalProcessor
(
BaseMultiModalProcessor
[
BagelProcessingInfo
]):
"""Multimodal processor for BAGEL model."""
def
_hf_processor_applies_updates
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptReplacement
]:
"""Replace image placeholders with the correct number of tokens."""
hf_config
=
self
.
info
.
get_hf_config
()
# Get the tokenizer to look up the image token ID
tokenizer
=
self
.
info
.
get_tokenizer
()
image_token_id
=
tokenizer
.
get_vocab
().
get
(
"<|image_pad|>"
)
if
image_token_id
is
None
:
raise
ValueError
(
"Image token '<|image_pad|>' not found in tokenizer vocabulary"
)
def
get_replacement_bagel
(
item_idx
:
int
):
# For BAGEL, calculate number of tokens based on max patch size
num_tokens
=
hf_config
.
vit_max_num_patch_per_side
**
2
# Use the image token ID from tokenizer
return
[
image_token_id
]
*
num_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement_bagel
,
)
]
def
_get_mm_fields_config
(
self
,
hf_inputs
:
Any
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
{
"pixel_values"
:
MultiModalFieldConfig
.
batched
(
"image"
),
}
@
MULTIMODAL_REGISTRY
.
register_processor
(
BagelMultiModalProcessor
,
info
=
BagelProcessingInfo
,
dummy_inputs
=
BagelDummyInputsBuilder
,
)
class
BagelForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
"""
BAGEL: A unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
The image generation part is not supported in vLLM.
"""
# Weight mapping from HF to vLLM
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"language_model."
:
"language_model."
,
"vit_model."
:
"vit_model."
,
"connector."
:
"connector."
,
"vit_pos_embed."
:
"vit_pos_embed."
,
}
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
# Ensure we have a BagelConfig (check by name to handle trust_remote_code)
# When trust_remote_code=True, the config comes from transformers_modules
if
type
(
config
).
__name__
!=
"BagelConfig"
:
raise
ValueError
(
f
"Expected BagelConfig, got
{
type
(
config
).
__name__
}
. "
"Make sure the model config is properly loaded."
)
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize language model (Qwen2)
# Pass the llm_config from BagelConfig to initialize Qwen2 properly
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
llm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
# Initialize vision model (SigLIP) if visual understanding is enabled
if
config
.
visual_und
:
# Fix vit_config: checkpoint has 26 layers (0-25) but config says 27
# Also disable head as it's not in checkpoint
vit_config
=
config
.
vit_config
if
vit_config
.
num_hidden_layers
==
27
:
logger
.
warning
(
"Overriding vit_config.num_hidden_layers from 27 to 26 "
"to match the Bagel model checkpoint."
)
vit_config
.
num_hidden_layers
=
26
if
not
hasattr
(
vit_config
,
"vision_use_head"
):
logger
.
warning
(
"Setting vit_config.vision_use_head to False as it is not "
"present in the Bagel model checkpoint."
)
vit_config
.
vision_use_head
=
False
self
.
vit_model
=
SiglipVisionModel
(
config
=
vit_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vit_model"
),
)
# Initialize connector (MLP)
vit_hidden_size
=
config
.
vit_config
.
hidden_size
llm_hidden_size
=
config
.
llm_config
.
hidden_size
self
.
connector
=
BagelVisionMLP
(
in_features
=
vit_hidden_size
,
hidden_features
=
llm_hidden_size
,
out_features
=
llm_hidden_size
,
act_layer
=
config
.
connector_act
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"connector"
),
)
# Position embedding for vision tokens
self
.
vit_pos_embed
=
PositionEmbedding
(
max_num_patch_per_side
=
config
.
vit_max_num_patch_per_side
,
hidden_size
=
llm_hidden_size
,
)
else
:
self
.
vit_model
=
None
self
.
connector
=
None
self
.
vit_pos_embed
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
BagelImageInputs
|
None
:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
pixel_values
is
None
:
return
None
return
BagelImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
)
def
_process_image_input
(
self
,
image_input
:
BagelImageInputs
)
->
tuple
[
torch
.
Tensor
,
...]:
"""Process image inputs through vision encoder and connector."""
pixel_values
=
image_input
[
"pixel_values"
]
# Handle potential extra batch dimension
# Expected shape: (batch_size * num_images, 3, H, W)
# But might receive: (batch_size, num_images, 3, H, W)
if
pixel_values
.
ndim
==
5
:
# Flatten batch and num_images dimensions
batch_size
,
num_images
,
channels
,
height
,
width
=
pixel_values
.
shape
pixel_values
=
pixel_values
.
reshape
(
batch_size
*
num_images
,
channels
,
height
,
width
)
# Get vision features from SigLIP
# pixel_values shape: (batch_size * num_images, 3, H, W)
vision_features
=
self
.
vit_model
(
pixel_values
)
# Pass through connector
vision_embeds
=
self
.
connector
(
vision_features
)
# Add position embeddings
batch_size
,
num_patches
,
hidden_size
=
vision_embeds
.
shape
patch_size
=
self
.
config
.
vit_config
.
patch_size
image_size
=
self
.
config
.
vit_config
.
image_size
# Calculate grid dimensions
num_patches_per_side
=
image_size
//
patch_size
# Create flattened position IDs (0 to num_patches-1)
# For BAGEL, we use extrapolate mode by default
h_coords
=
torch
.
arange
(
num_patches_per_side
,
device
=
vision_embeds
.
device
)
w_coords
=
torch
.
arange
(
num_patches_per_side
,
device
=
vision_embeds
.
device
)
position_ids
=
(
h_coords
[:,
None
]
*
self
.
config
.
vit_max_num_patch_per_side
+
w_coords
).
flatten
()
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand
(
batch_size
,
-
1
).
flatten
()
# Add position embeddings
pos_embeds
=
self
.
vit_pos_embed
(
position_ids
)
pos_embeds
=
pos_embeds
.
reshape
(
batch_size
,
num_patches
,
hidden_size
)
# Ensure pos_embeds are on the same device as vision_embeds
pos_embeds
=
pos_embeds
.
to
(
vision_embeds
.
device
)
vision_embeds
=
vision_embeds
+
pos_embeds
# Split by image
return
tuple
(
vision_embeds
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
"""Get multimodal embeddings from input."""
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
[]
return
self
.
_process_image_input
(
image_input
)
def
get_language_model
(
self
)
->
nn
.
Module
:
return
self
.
language_model
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
"""Run forward pass for BAGEL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a batch.
positions: Flattened (concatenated) position ids corresponding to a batch.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""Load weights from checkpoint."""
skip_prefixes
=
[]
# Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
skip_prefixes
.
append
(
"vit_pos_embed.pos_embed"
)
# If visual understanding is disabled, skip vision-related weights
if
self
.
vit_model
is
None
:
skip_prefixes
.
extend
([
"vit_model."
,
"connector."
,
"vit_pos_embed"
])
# Skip generation-related weights since we only support text2text and image2text
# Filter out all image generation components:
# - 'moe_gen': MoE generation weights
# - 'latent_pos_embed': Latent position embeddings for VAE
# - 'llm2vae', 'vae2llm': LLM-VAE projections
# - 'time_embedder': Timestep embeddings for diffusion
# - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder
generation_keywords
=
[
"moe_gen"
,
"latent_pos_embed"
,
"llm2vae"
,
"vae2llm"
,
"time_embedder"
,
]
vae_prefixes
=
[
"decoder."
,
"encoder."
,
]
# VAE encoder/decoder, not vision encoder
filtered_weights
=
[]
for
name
,
tensor
in
weights
:
# Skip generation-related keywords
if
any
(
skip
in
name
for
skip
in
generation_keywords
):
continue
if
any
(
name
.
startswith
(
prefix
)
for
prefix
in
vae_prefixes
):
continue
if
"patch_embedding.weight"
in
name
and
tensor
.
ndim
==
2
:
out_channels
=
tensor
.
shape
[
0
]
in_features
=
tensor
.
shape
[
1
]
patch_size
=
self
.
config
.
vit_config
.
patch_size
in_channels
=
self
.
config
.
vit_config
.
num_channels
if
in_features
==
in_channels
*
patch_size
*
patch_size
:
tensor
=
tensor
.
reshape
(
out_channels
,
patch_size
,
patch_size
,
in_channels
)
tensor
=
tensor
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
filtered_weights
.
append
((
name
,
tensor
))
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
return
loader
.
load_weights
(
filtered_weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
vllm/model_executor/models/baichuan.py
View file @
a3f8d5dd
...
...
@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
else
:
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
rope_parameters
=
rope_parameters
,
)
...
...
vllm/model_executor/models/bailing_moe.py
View file @
a3f8d5dd
...
...
@@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.dense"
,
)
self
.
rotary_dim
=
getattr
(
config
,
"rotary_dim"
,
self
.
head_dim
)
rotary_dim
=
getattr
(
config
,
"rotary_dim"
,
self
.
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
self
.
head_dim
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_dim
,
max_position
=
config
.
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
...
...
vllm/model_executor/models/bamba.py
View file @
a3f8d5dd
...
...
@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
if
hasattr
(
config
,
"attn_rotary_emb"
):
rotary_dim
=
config
.
attn_rotary_emb
# for backward compatibility
else
:
rotary_dim
=
self
.
head_dim
# default
rotary_dim
=
getattr
(
config
,
"attn_rotary_emb"
,
self
.
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
self
.
head_dim
self
.
rotary_emb
=
get_rope
(
head_size
=
self
.
head_dim
,
rotary_dim
=
rotary_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
...
...
vllm/model_executor/models/chameleon.py
View file @
a3f8d5dd
...
...
@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
self
.
k_norm
=
ChameleonLayerNorm
((
self
.
num_kv_heads
,
self
.
head_dim
))
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
rope_parameters
,
)
...
...
vllm/model_executor/models/chatglm.py
View file @
a3f8d5dd
...
...
@@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio
=
getattr
(
config
,
"rope_ratio"
,
1.0
)
max_positions
=
getattr
(
config
,
"seq_length"
,
8192
)
rope_parameters
=
{
"rope_type"
:
"default"
,
"rope_theta"
:
10000
*
rope_ratio
}
rope_parameters
=
{
"rope_type"
:
"default"
,
"rope_theta"
:
10000
*
rope_ratio
,
"partial_rotary_factor"
:
0.5
,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style
=
not
config
.
original_rope
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
//
2
,
max_position
=
max_positions
,
rope_parameters
=
rope_parameters
,
is_neox_style
=
is_neox_style
,
...
...
vllm/model_executor/models/commandr.py
View file @
a3f8d5dd
...
...
@@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
False
,
...
...
vllm/model_executor/models/config.py
View file @
a3f8d5dd
...
...
@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
config
.
hidden_act
=
"geglu"
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_dim
=
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
head_dim
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"rope_parameters"
:
config
.
rope_parameters
,
}
...
...
@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if
not
model_config
.
enforce_eager
:
max_position
=
round_up
(
max_position
,
8
)
rotary_dim
=
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
head_dim
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
max_position
,
"rope_parameters"
:
config
.
rope_parameters
,
}
...
...
@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config
.
num_hidden_layers
=
config
.
n_layer
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_emb_dim
=
int
(
head_dim
*
config
.
rotary_emb_fraction
)
max_trained_positions
=
getattr
(
config
,
"max_trained_positions"
,
2048
)
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
rotary_emb_dim
,
"max_position"
:
max_trained_positions
,
"rope_parameters"
:
config
.
rope_parameters
,
}
...
...
@@ -214,7 +215,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
tokens
=
getattr
(
config
,
"classifier_from_token"
,
None
)
assert
tokens
is
not
None
and
len
(
tokens
)
==
2
,
(
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/
qwen3
_reranker.py"
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/
offline
_reranker.py"
)
vllm_config
.
model_config
.
hf_config
.
method
=
"from_2_way_softmax"
...
...
@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
config
.
hidden_act
=
"geglu"
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_dim
=
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
head_dim
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"rotary_dim"
:
getattr
(
config
,
"rotary_emb_dim"
,
head_dim
),
"max_position"
:
config
.
max_position_embeddings
,
"rope_parameters"
:
config
.
rope_parameters
,
}
...
...
@@ -361,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
else
:
kernel_block_alignment_size
=
16
if
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
and
model_config
.
get_head_size
()
==
256
and
(
attention_config
.
backend
is
None
...
...
vllm/model_executor/models/dbrx.py
View file @
a3f8d5dd
...
...
@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position
,
rope_parameters
=
rope_parameters
,
is_neox_style
=
True
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
a3f8d5dd
...
...
@@ -85,6 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata
,
)
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MLAAttentionSpec
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
.interfaces
import
MixtureOfExperts
,
SupportsEagle
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
...
...
@@ -158,7 +159,6 @@ class DeepseekAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
)
...
...
@@ -501,7 +501,6 @@ class DeepseekV2Attention(nn.Module):
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
False
,
...
...
@@ -620,8 +619,15 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
torch
.
float8_e4m3fn
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
return
sparse_attn_indexer_fake
(
hidden_states
,
k_cache_prefix
,
...
...
@@ -655,17 +661,17 @@ def sparse_attn_indexer(
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
head_dim
],
device
=
k
.
device
,
dtype
=
fp8_dtype
,
)
k_scale
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
4
],
device
=
k
.
device
,
dtype
=
torch
.
uint8
,
)
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
...
...
@@ -781,15 +787,6 @@ def sparse_attn_indexer_fake(
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv
=
torch
.
empty
(
[
total_seq_lens
,
head_dim
+
4
],
device
=
k
.
device
,
dtype
=
torch
.
uint8
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
_k_fp8
=
_flattened_kv
[...,
:
head_dim
].
view
(
fp8_dtype
).
contiguous
()
_k_scale
=
_flattened_kv
[...,
head_dim
:].
view
(
torch
.
float32
).
contiguous
()
return
topk_indices_buffer
...
...
@@ -1020,7 +1017,6 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
False
,
...
...
@@ -1040,7 +1036,6 @@ class DeepseekV2MLAAttention(nn.Module):
if
self
.
is_v32
:
self
.
indexer_rope_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
...
...
vllm/model_executor/models/dots1.py
View file @
a3f8d5dd
...
...
@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
)
...
...
vllm/model_executor/models/dots_ocr.py
View file @
a3f8d5dd
...
...
@@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn
import
LayerNorm
from
transformers.models.qwen2_vl
import
Qwen2VLProcessor
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
(
maybe_get_vit_flash_attn_backend
,
from
vllm.attention.layer
s.mm_encoder_attention
import
(
MMEncoderAttention
,
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
MultiModalConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed.parallel_state
import
(
...
...
@@ -30,6 +29,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding.common
import
(
ApplyRotaryEmb
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
...
...
@@ -159,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
return
processor
def
rotate_half
(
x
):
"""Rotates half the hidden dims of the input."""
x1
=
x
[...,
:
x
.
shape
[
-
1
]
//
2
]
x2
=
x
[...,
x
.
shape
[
-
1
]
//
2
:]
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
def
apply_rotary_pos_emb_vision
(
tensor
:
torch
.
Tensor
,
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_dtype
=
tensor
.
dtype
tensor
=
tensor
.
float
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cos
=
cos
.
unsqueeze
(
1
).
repeat
(
1
,
1
,
2
).
unsqueeze
(
0
).
float
()
sin
=
sin
.
unsqueeze
(
1
).
repeat
(
1
,
1
,
2
).
unsqueeze
(
0
).
float
()
output
=
(
tensor
*
cos
)
+
(
rotate_half
(
tensor
)
*
sin
)
output
=
output
.
to
(
orig_dtype
)
return
output
class
VisionRotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
theta
:
float
=
10000.0
)
->
None
:
super
().
__init__
()
...
...
@@ -254,11 +230,15 @@ class DotsVisionAttention(nn.Module):
bias
:
bool
=
True
,
*
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
)
->
None
:
super
().
__init__
()
use_data_parallel
=
(
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
multimodal_config
else
False
)
self
.
embed_dim
=
dim
self
.
tp_size
=
(
...
...
@@ -287,31 +267,18 @@ class DotsVisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
disable_tp
=
use_data_parallel
,
)
# Select attention backend
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
hidden_size_per_attention_head
,
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
self
.
attn_backend
,
self
.
flash_attn_varlen_func
=
(
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
attn_backend_override
=
attn_backend_override
,
)
self
.
apply_rotary_emb
=
ApplyRotaryEmb
(
enforce_enable
=
True
,
enable_fp32_compute
=
True
,
)
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"Unsupported vision attention backend:
{
self
.
attn_backend
}
"
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}
def
forward
(
self
,
...
...
@@ -319,7 +286,7 @@ class DotsVisionAttention(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
|
None
=
None
,
*
,
max_seqlen
:
int
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# [S, C] -> [S, B=1, C]
x
=
hidden_states
.
unsqueeze
(
1
)
...
...
@@ -333,44 +300,20 @@ class DotsVisionAttention(nn.Module):
if
rotary_pos_emb
is
not
None
:
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb
)
qk_rotated
=
self
.
apply_rotary_emb
(
qk_concat
,
rotary_pos_emb
.
cos
(),
rotary_pos_emb
.
sin
(),
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
q_
=
q
.
reshape
(
bs
*
q
.
shape
[
1
],
q
.
shape
[
2
],
q
.
shape
[
3
])
k_
=
k
.
reshape
(
bs
*
k
.
shape
[
1
],
k
.
shape
[
2
],
k
.
shape
[
3
])
v_
=
v
.
reshape
(
bs
*
v
.
shape
[
1
],
v
.
shape
[
2
],
v
.
shape
[
3
])
output
=
self
.
flash_attn_varlen_func
(
q_
,
k_
,
v_
,
cu_seqlens_q
=
cu_seqlens
,
cu_seqlens_k
=
cu_seqlens
,
max_seqlen_q
=
max_seqlen
,
max_seqlen_k
=
max_seqlen
,
dropout_p
=
0.0
,
causal
=
False
,
)
context_layer
=
output
.
view
(
bs
,
-
1
,
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
,
)
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
outputs
=
[]
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
s
=
int
(
cu_seqlens
[
i
-
1
])
e
=
int
(
cu_seqlens
[
i
])
q_i
=
q
[:,
s
:
e
].
permute
(
0
,
2
,
1
,
3
)
k_i
=
k
[:,
s
:
e
].
permute
(
0
,
2
,
1
,
3
)
v_i
=
v
[:,
s
:
e
].
permute
(
0
,
2
,
1
,
3
)
out_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
out_i
=
out_i
.
permute
(
0
,
2
,
1
,
3
)
outputs
.
append
(
out_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
if
outputs
else
q
[:,
:
0
]
else
:
raise
RuntimeError
(
"Unsupported attention backend"
)
context_layer
=
self
.
attn
(
query
=
q
,
key
=
k
,
value
=
v
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
# [B,S,H,D] -> [S,B,H*D] -> [S, C]
context_layer
=
context_layer
.
permute
(
1
,
0
,
2
,
3
).
contiguous
()
...
...
@@ -385,14 +328,19 @@ class DotsSwiGLUFFN(nn.Module):
config
,
*
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
hidden_features
=
config
.
intermediate_size
in_features
=
config
.
embed_dim
bias
=
config
.
use_bias
use_data_parallel
=
(
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
multimodal_config
else
False
)
# Referenced aimv2.py AIMv2SwiGLUFFN
self
.
fc13
=
MergedColumnParallelLinear
(
in_features
,
...
...
@@ -498,9 +446,8 @@ class DotsVisionBlock(nn.Module):
config
,
*
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
):
super
().
__init__
()
...
...
@@ -510,16 +457,15 @@ class DotsVisionBlock(nn.Module):
num_heads
=
config
.
num_attention_heads
,
bias
=
config
.
use_bias
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
norm1
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
mlp
=
DotsSwiGLUFFN
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
use_data_parallel
,
)
self
.
norm2
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -546,12 +492,11 @@ class DotsVisionTransformer(nn.Module):
self
,
config
:
DotsVisionConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
*
,
num_hidden_layers_override
:
int
|
None
=
None
,
require_post_norm
:
bool
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -561,6 +506,11 @@ class DotsVisionTransformer(nn.Module):
head_dim
=
config
.
embed_dim
//
config
.
num_attention_heads
self
.
rotary_pos_emb
=
VisionRotaryEmbedding
(
head_dim
//
2
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
...
...
@@ -578,9 +528,8 @@ class DotsVisionTransformer(nn.Module):
DotsVisionBlock
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
i
}
"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
for
i
in
range
(
num_layers
)
]
...
...
@@ -592,6 +541,11 @@ class DotsVisionTransformer(nn.Module):
else
:
self
.
post_trunk_norm
=
None
use_data_parallel
=
(
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
multimodal_config
else
False
)
self
.
merger
=
PatchMerger
(
dim
=
config
.
hidden_size
,
context_dim
=
config
.
embed_dim
,
...
...
@@ -647,7 +601,7 @@ class DotsVisionTransformer(nn.Module):
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
.
item
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
def
forward
(
...
...
@@ -733,17 +687,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self
.
config
.
vision_config
=
vision_config
else
:
vision_config
=
self
.
config
.
vision_config
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
vision_tower
=
DotsVisionTransformer
(
vision_config
,
quant_config
=
self
.
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
language_model
:
Qwen2ForCausalLM
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
...
...
vllm/model_executor/models/ernie45_moe.py
View file @
a3f8d5dd
...
...
@@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
rope_parameters
,
is_neox_style
=
False
,
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
a3f8d5dd
...
...
@@ -33,14 +33,14 @@ import numpy as np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
from
transformers
import
BatchFeature
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
(
maybe_get_vit_flash_attn_backend
,
from
vllm.attention.layer
s.mm_encoder_attention
import
(
MMEncoderAttention
,
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
MultiModalConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding.common
import
(
ApplyRotaryEmb
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
...
...
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
PromptUpdate
,
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === #
def
rotate_half
(
x
:
torch
.
Tensor
,
interleaved
:
bool
=
False
)
->
torch
.
Tensor
:
if
not
interleaved
:
x1
,
x2
=
x
.
chunk
(
2
,
dim
=-
1
)
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
else
:
x1
,
x2
=
x
[...,
::
2
],
x
[...,
1
::
2
]
return
rearrange
(
torch
.
stack
((
-
x2
,
x1
),
dim
=-
1
),
"... d two -> ... (d two)"
,
two
=
2
)
def
apply_rotary_emb_torch
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
interleaved
:
bool
=
False
)
->
torch
.
Tensor
:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim
=
cos
.
shape
[
-
1
]
*
2
assert
ro_dim
<=
x
.
shape
[
-
1
]
cos
=
repeat
(
cos
,
"... d -> ... 1 (2 d)"
if
not
interleaved
else
"... d -> ... 1 (d 2)"
)
sin
=
repeat
(
sin
,
"... d -> ... 1 (2 d)"
if
not
interleaved
else
"... d -> ... 1 (d 2)"
)
return
torch
.
cat
(
[
x
[...,
:
ro_dim
]
*
cos
+
rotate_half
(
x
[...,
:
ro_dim
],
interleaved
)
*
sin
,
x
[...,
ro_dim
:],
],
dim
=-
1
,
)
def
apply_rotary_pos_emb_vision
(
t
:
torch
.
Tensor
,
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
t_
=
t
.
float
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
apply_rotary_emb
=
apply_rotary_emb_torch
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
output
=
apply_rotary_emb
(
t_
,
cos
,
sin
).
type_as
(
t
)
return
output
def
all_gather_interleave
(
local_tensor
,
hidden_size
:
int
,
tp_size
:
int
):
"""All-gather the input tensor interleavely across model parallel group."""
import
torch.distributed
as
dist
...
...
@@ -163,8 +119,8 @@ class Ernie4_5_VisionAttention(nn.Module):
num_heads
:
int
,
projection_size
:
int
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Per attention head and per partition values.
...
...
@@ -193,33 +149,18 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
,
)
# Detect attention implem
ent
at
ion
.
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
attn
=
MMEncoderAtt
ention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
()
,
attn_backend_override
=
attn_backend_override
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
self
.
attn_backend
,
self
.
flash_attn_varlen_func
=
(
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
attn_backend_override
=
attn_backend_override
,
)
self
.
apply_rotary_emb
=
ApplyRotaryEmb
(
enforce_enable
=
True
,
enable_fp32_compute
=
True
,
)
if
self
.
attn_backend
not
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
TORCH_SDPA
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"Ernie45-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
}
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# [s, b, 3 * head * head_dim]
seq_len
,
bs
,
_
=
qkv
.
shape
...
...
@@ -253,58 +194,32 @@ class Ernie4_5_VisionAttention(nn.Module):
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
# Only used for Flash Attention
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q
,
k
,
v
=
self
.
split_qkv
(
x
)
batch_size
=
q
.
shape
[
1
]
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
for
x
in
(
q
,
k
,
v
))
if
rotary_pos_emb
is
not
None
:
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
=
cu_seqlens
,
cu_seqlens_k
=
cu_seqlens
,
max_seqlen_q
=
max_seqlen
,
max_seqlen_k
=
max_seqlen
,
dropout_p
=
0.0
,
causal
=
False
,
qk_rotated
=
self
.
apply_rotary_emb
(
qk_concat
,
rotary_pos_emb
.
cos
(),
rotary_pos_emb
.
sin
(),
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
context_layer
=
rearrange
(
output
,
"(b s) h d -> s b (h d)"
,
b
=
batch_size
).
contiguous
()
elif
self
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
:
# Execute attention entry by entry for speed & less VRAM.
outputs
=
[]
lens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
q_chunks
=
torch
.
split
(
q
,
lens
,
dim
=
1
)
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
q_i
,
k_i
,
v_i
=
(
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q_i
,
k_i
,
v_i
]
)
output_i
=
F
.
scaled_dot_product_attention
(
q_i
,
k_i
,
v_i
,
dropout_p
=
0.0
)
output_i
=
rearrange
(
output_i
,
"b h s d -> b s h d "
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
context_layer
=
rearrange
(
context_layer
,
"b s h d -> s b (h d)"
).
contiguous
()
output
=
self
.
attn
(
query
=
q
,
key
=
k
,
value
=
v
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
context_layer
=
rearrange
(
output
,
"b s h d -> s b (h d)"
).
contiguous
()
output
,
_
=
self
.
proj
(
context_layer
)
return
output
...
...
@@ -350,8 +265,8 @@ class Ernie4_5_VisionBlock(nn.Module):
act_layer
:
type
[
nn
.
Module
]
=
QuickGELU
,
norm_layer
:
Callable
[[
int
],
nn
.
Module
]
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -366,8 +281,8 @@ class Ernie4_5_VisionBlock(nn.Module):
num_heads
=
num_heads
,
projection_size
=
dim
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
mlp
=
Ernie4_5_VisionMLP
(
...
...
@@ -383,7 +298,7 @@ class Ernie4_5_VisionBlock(nn.Module):
hidden_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
# Only used for Flash Attention
)
->
torch
.
Tensor
:
hidden_states
=
hidden_states
+
self
.
attn
(
self
.
norm1
(
hidden_states
),
...
...
@@ -441,8 +356,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
vision_config
,
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
AttentionBackendEnum
|
None
=
None
,
)
->
None
:
super
().
__init__
()
patch_size
=
vision_config
.
patch_size
...
...
@@ -477,8 +392,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
mlp_ratio
=
mlp_ratio
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
attn_backend_override
=
attn_backend_override
,
)
for
layer_idx
in
range
(
depth
)
]
...
...
@@ -489,6 +404,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
)
self
.
ln
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
1e-6
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
else
None
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
...
...
@@ -535,13 +453,13 @@ class Ernie4_5_VisionTransformer(nn.Module):
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
int
|
None
:
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
max_seqlen
=
None
if
(
self
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
or
self
.
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
.
item
()
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
()
return
max_seqlen
def
forward
(
...
...
@@ -1304,17 +1222,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
vision_model
=
Ernie4_5_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
language_model
=
Ernie4_5_VLMoeForCausalLM
(
...
...
vllm/model_executor/models/exaone.py
View file @
a3f8d5dd
...
...
@@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
is_neox_style
,
...
...
vllm/model_executor/models/exaone4.py
View file @
a3f8d5dd
...
...
@@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module):
set_default_rope_theta
(
config
,
default_theta
=
1000000
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
is_neox_style
,
...
...
vllm/model_executor/models/falcon.py
View file @
a3f8d5dd
...
...
@@ -167,7 +167,6 @@ class FalconAttention(nn.Module):
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
)
...
...
vllm/model_executor/models/falcon_h1.py
View file @
a3f8d5dd
...
...
@@ -242,14 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
if
hasattr
(
config
,
"attn_rotary_emb"
):
rotary_dim
=
config
.
attn_rotary_emb
# for backward compatibility
else
:
rotary_dim
=
self
.
head_dim
# default
rotary_dim
=
getattr
(
config
,
"attn_rotary_emb"
,
self
.
head_dim
)
config
.
rope_parameters
[
"partial_rotary_factor"
]
=
rotary_dim
/
self
.
head_dim
self
.
rotary_emb
=
get_rope
(
head_size
=
self
.
head_dim
,
rotary_dim
=
rotary_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment