Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
141 additions
and
111 deletions
+141
-111
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+21
-14
vllm/model_executor/models/llama_eagle.py
vllm/model_executor/models/llama_eagle.py
+4
-2
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+15
-6
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+2
-3
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+7
-6
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+4
-2
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+1
-4
vllm/model_executor/models/mimo_mtp.py
vllm/model_executor/models/mimo_mtp.py
+1
-1
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+0
-3
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+7
-5
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+2
-3
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+4
-3
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+4
-6
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+1
-1
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+4
-12
vllm/model_executor/models/nemotron_nas.py
vllm/model_executor/models/nemotron_nas.py
+46
-2
vllm/model_executor/models/nvlm_d.py
vllm/model_executor/models/nvlm_d.py
+8
-5
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+2
-14
vllm/model_executor/models/olmo2.py
vllm/model_executor/models/olmo2.py
+7
-15
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+1
-4
No files found.
vllm/model_executor/models/llama.py
View file @
4eabe123
...
...
@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
is_neox_style
=
True
is_gguf
=
quant_config
and
quant_config
.
get_name
()
==
"gguf"
if
is_gguf
and
config
.
model_type
==
"llama"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
self
.
_init_rotary_emb
(
config
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
partial_rotary_factor
=
self
.
partial_rotary_factor
,
)
quant_config
=
quant_config
)
if
hasattr
(
config
,
"interleaved_sliding_window"
):
interleaved_sliding_window
=
config
.
interleaved_sliding_window
...
...
@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
def
_init_rotary_emb
(
self
,
config
:
LlamaConfig
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]],
quant_config
:
Optional
[
QuantizationConfig
])
->
None
:
is_neox_style
=
True
is_gguf
=
quant_config
and
quant_config
.
get_name
()
==
"gguf"
if
is_gguf
and
config
.
model_type
==
"llama"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
base
=
self
.
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
partial_rotary_factor
=
self
.
partial_rotary_factor
,
)
class
LlamaDecoderLayer
(
nn
.
Module
):
...
...
vllm/model_executor/models/llama_eagle.py
View file @
4eabe123
...
...
@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
class
EagleLlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
target_layer_num
=
vllm_config
.
model_config
.
get_num_layers
(
vllm_config
.
parallel_config
)
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
"model"
,
start_layer_id
=
s
tart_layer_
id
)
start_layer_id
=
tar
ge
t_layer_
num
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
,
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
4eabe123
...
...
@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
class
Eagle3LlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
start_layer_id
:
int
=
0
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
vllm_config
.
\
speculative_config
.
draft_model_config
.
hf_config
target_layer_num
=
vllm_config
.
model_config
.
get_num_layers
(
vllm_config
.
parallel_config
)
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
start_layer_id
=
start_layer_id
,
prefix
=
"model"
)
prefix
=
"model"
,
start_layer_id
=
target_layer_num
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
lm_head
=
ParallelLMHead
(
...
...
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
draft_vocab_size
,
scale
=
logit_scale
)
self
.
draft_id_to_target_id
=
nn
.
Parameter
(
torch
.
zeros
((
self
.
config
.
draft_vocab_size
),
dtype
=
torch
.
long
).
type
(
torch
.
LongTensor
),
torch
.
zeros
(
self
.
config
.
draft_vocab_size
,
dtype
=
torch
.
long
),
requires_grad
=
False
,
)
...
...
@@ -213,6 +214,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
if
self
.
draft_id_to_target_id
is
None
:
return
logits
base
=
torch
.
arange
(
self
.
config
.
draft_vocab_size
,
device
=
logits
.
device
)
targets
=
base
+
self
.
draft_id_to_target_id
logits_new
=
logits
.
new_full
((
...
...
@@ -245,4 +249,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
name
=
"model."
+
name
model_weights
[
name
]
=
loaded_weight
return
loader
.
load_weights
(
model_weights
.
items
())
loaded_weights
=
loader
.
load_weights
(
model_weights
.
items
())
if
'd2t'
not
in
loaded_weights
:
self
.
draft_id_to_target_id
=
None
return
loaded_weights
vllm/model_executor/models/llava.py
View file @
4eabe123
...
...
@@ -721,9 +721,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
batch.
pixel_values: The pixels in each input image.
:::{seealso}
{class}`LlavaImageInputs`
:::
Info:
[LlavaImageInputs][]
"""
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
...
...
vllm/model_executor/models/llava_next.py
View file @
4eabe123
...
...
@@ -135,11 +135,13 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
current_aspect_ratio
=
current_width
/
current_height
if
aspect_ratio
>
current_aspect_ratio
:
new_height
=
(
original_height
*
current_width
)
//
original_width
new_height
=
int
(
round
(
original_height
*
(
current_width
/
original_width
),
7
))
padding
=
(
current_height
-
new_height
)
//
2
current_height
=
current_height
-
(
2
*
padding
)
else
:
new_width
=
(
original_width
*
current_height
)
//
original_height
new_width
=
int
(
round
(
original_width
*
(
current_height
/
original_height
),
7
))
padding
=
(
current_width
-
new_width
)
//
2
current_width
=
current_width
-
(
2
*
padding
)
...
...
@@ -538,7 +540,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
Unlike in LLaVA-1.5, the number of image tokens inputted to the language
model depends on the original size of the input image. Including the
original image token in the input, the required number of image tokens
is given by
{func}`
get_llava_next_image_feature_size
`
.
is given by
[
get_llava_next_image_feature_size
][]
.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
...
...
@@ -549,9 +551,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image.
:::{seealso}
{class}`LlavaNextImageInputs`
:::
Info:
[LlavaNextImageInputs][]
"""
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
...
...
vllm/model_executor/models/llava_onevision.py
View file @
4eabe123
...
...
@@ -116,11 +116,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
current_aspect_ratio
=
current_width
/
current_height
if
aspect_ratio
>
current_aspect_ratio
:
new_height
=
(
original_height
*
current_width
)
//
original_width
new_height
=
int
(
round
(
original_height
*
(
current_width
/
original_width
),
7
))
padding
=
(
current_height
-
new_height
)
//
2
current_height
=
current_height
-
(
2
*
padding
)
else
:
new_width
=
(
original_width
*
current_height
)
//
original_height
new_width
=
int
(
round
(
original_width
*
(
current_height
/
original_height
),
7
))
padding
=
(
current_width
-
new_width
)
//
2
current_width
=
current_width
-
(
2
*
padding
)
...
...
vllm/model_executor/models/medusa.py
View file @
4eabe123
...
...
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
needs to have truncated_vocab_size (=k) as an attribute."""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
if
hasattr
(
vllm_config
,
'draft_model_config'
):
config
=
vllm_config
.
draft_model_config
.
hf_config
else
:
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
super
().
__init__
()
self
.
config
=
config
self
.
blocks
=
nn
.
ModuleList
([
...
...
vllm/model_executor/models/mimo_mtp.py
View file @
4eabe123
...
...
@@ -250,7 +250,7 @@ class MiMoMTP(nn.Module):
return
loaded_params
def
map_model_name_to_mtp_param_name
(
self
,
name
:
str
)
->
str
:
import
re
import
re
gex
as
re
name_without_prefix
=
[
"token_layernorm"
,
"hidden_layernorm"
,
"input_proj"
,
"final_layernorm"
...
...
vllm/model_executor/models/minicpm.py
View file @
4eabe123
...
...
@@ -242,9 +242,6 @@ class MiniCPMAttention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
# set rope as fp32 instead of bf16
self
.
rotary_emb
.
cos_sin_cache
=
self
.
rotary_emb
.
_compute_cos_sin_cache
(
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
4eabe123
...
...
@@ -2,10 +2,10 @@
"""Inference-only MiniMaxText01 model."""
import
copy
import
math
import
re
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
import
regex
as
re
import
torch
import
torch.distributed
import
torch.nn.functional
as
F
...
...
@@ -604,8 +604,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
head_dim
=
getattr
(
config
,
"head_dim"
,
config
.
hidden_size
//
config
.
num_attention_heads
)
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
head_dim
is
None
:
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
if
hasattr
(
config
,
"max_model_len"
)
and
isinstance
(
config
.
max_model_len
,
int
):
max_position_embeddings
=
min
(
config
.
max_position_embeddings
,
...
...
@@ -861,8 +862,9 @@ class MiniMaxText01Model(nn.Module):
cache_shape
=
self
.
cache_shape
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
head_dim
=
getattr
(
config
,
"head_dim"
,
config
.
hidden_size
//
config
.
num_attention_heads
)
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
head_dim
is
None
:
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
if
hasattr
(
config
,
"max_model_len"
)
and
isinstance
(
config
.
max_model_len
,
int
):
max_position_embeddings
=
min
(
config
.
max_position_embeddings
,
...
...
vllm/model_executor/models/mistral3.py
View file @
4eabe123
...
...
@@ -559,9 +559,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
batch.
pixel_values: The pixels in each input image.
:::{seealso}
{class}`Mistral3ImagePixelInputs`
:::
Info:
[Mistral3ImagePixelInputs][]
"""
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
...
...
vllm/model_executor/models/mixtral.py
View file @
4eabe123
...
...
@@ -138,8 +138,9 @@ class MixtralAttention(nn.Module):
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MixtralConfig has an optional head_dim argument
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
self
.
head_dim
is
None
:
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
@@ -482,5 +483,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"rotary_emb.inv_freq"
]
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/mixtral_quant.py
View file @
4eabe123
...
...
@@ -193,8 +193,9 @@ class MixtralAttention(nn.Module):
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MixtralConfig has an optional head_dim argument
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
self
.
head_dim
is
None
:
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
@@ -447,8 +448,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"rotary_emb.inv_freq"
]),
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/molmo.py
View file @
4eabe123
...
...
@@ -965,7 +965,7 @@ def select_tiling(
class
MolmoProcessorWrapper
:
"""
Wraps
{class}
`MolmoProcessor` so that it can be called directly.
Wraps `MolmoProcessor` so that it can be called directly.
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
...
...
vllm/model_executor/models/nemotron.py
View file @
4eabe123
...
...
@@ -158,8 +158,9 @@ class NemotronAttention(nn.Module):
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
self
.
head_dim
is
None
:
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
@@ -502,14 +503,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
]),
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/nemotron_nas.py
View file @
4eabe123
...
...
@@ -23,18 +23,20 @@
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
torch
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.attention
import
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
return
n
+
k
-
(
n
%
k
)
class
DeciLMAttention
(
LlamaAttention
):
def
__init__
(
self
,
config
:
LlamaConfig
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias_o_proj
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
super
().
__init__
(
config
,
hidden_size
,
num_heads
,
num_kv_heads
,
rope_theta
,
rope_scaling
,
max_position_embeddings
,
quant_config
,
bias
,
bias_o_proj
,
cache_config
,
prefix
,
attn_type
)
def
_init_rotary_emb
(
self
,
config
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]],
quant_config
:
Optional
[
QuantizationConfig
])
->
None
:
# Enables YARN for Mistral and LLaMA4 derivatives.
is_neox_style
=
True
if
hasattr
(
config
,
"position_embedding_type"
):
is_neox_style
=
config
.
position_embedding_type
not
in
[
"mistral_yarn"
,
"rope_llama4"
]
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
base
=
self
.
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
partial_rotary_factor
=
self
.
partial_rotary_factor
)
class
DeciLMDecoderLayer
(
nn
.
Module
):
def
__init__
(
...
...
@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
if
not
self
.
_is_no_op_attention
:
num_kv_heads
=
(
config
.
num_attention_heads
//
block_config
.
attention
.
n_heads_in_group
)
self
.
self_attn
=
Llama
Attention
(
self
.
self_attn
=
DeciLM
Attention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
...
...
vllm/model_executor/models/nvlm_d.py
View file @
4eabe123
...
...
@@ -22,9 +22,10 @@ from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails
)
from
.intern_vit
import
InternVisionModel
from
.internvl
import
(
BaseInternVLProcessingInfo
,
BaseInternVLProcessor
,
InternVLChatModel
,
InternVLDummyInputsBuilder
,
InternVLMultiModalProcessor
)
from
.internvl
import
(
BaseInternVLDummyInputsBuilder
,
BaseInternVLMultiModalProcessor
,
BaseInternVLProcessingInfo
,
BaseInternVLProcessor
,
InternVLChatModel
)
IMG_PAD
=
"<|vision_pad|>"
...
...
@@ -84,7 +85,8 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
)
class
NVLMDummyInputsBuilder
(
InternVLDummyInputsBuilder
[
NVLMProcessingInfo
]):
class
NVLMDummyInputsBuilder
(
BaseInternVLDummyInputsBuilder
[
NVLMProcessingInfo
]
):
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
...
...
@@ -110,7 +112,8 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
}
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
class
NVLMMultiModalProcessor
(
BaseInternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
def
_get_prompt_updates
(
self
,
...
...
vllm/model_executor/models/olmo.py
View file @
4eabe123
...
...
@@ -382,19 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
,
"lm_head.weight"
]
if
self
.
config
.
tie_word_embeddings
else
[
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
]),
skip_prefixes
=
([
"lm_head.weight"
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/olmo2.py
View file @
4eabe123
...
...
@@ -314,7 +314,8 @@ class Olmo2Model(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -325,6 +326,7 @@ class Olmo2Model(nn.Module):
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
...
...
@@ -347,6 +349,8 @@ class Olmo2Model(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Olmo2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -403,19 +407,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
,
"lm_head.weight"
]
if
self
.
config
.
tie_word_embeddings
else
[
"rotary_emb.inv_freq"
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
]),
skip_prefixes
=
([
"lm_head.weight"
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/olmoe.py
View file @
4eabe123
...
...
@@ -442,8 +442,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"rotary_emb.inv_freq"
],
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
Prev
1
…
25
26
27
28
29
30
31
32
33
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment