Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc73e982
"vscode:/vscode.git/clone" did not exist on "73df49ef3a220c79abfffc36bdfb4e8dee61226b"
Unverified
Commit
bc73e982
authored
Oct 29, 2024
by
Michael Goin
Committed by
GitHub
Oct 29, 2024
Browse files
[Bugfix] Fix prefix strings for quantized VLMs (#9772)
parent
8d772410
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
288 additions
and
97 deletions
+288
-97
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+8
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+4
-1
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+39
-19
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+39
-17
vllm/model_executor/models/internlm2_ve.py
vllm/model_executor/models/internlm2_ve.py
+12
-4
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+4
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+15
-5
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+8
-2
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+8
-2
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+8
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+26
-8
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+27
-7
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+5
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+14
-5
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+4
-1
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+37
-13
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+6
-2
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+4
-1
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+15
-0
No files found.
vllm/model_executor/model_loader/loader.py
View file @
bc73e982
...
...
@@ -147,15 +147,20 @@ def _get_model_initialization_kwargs(
return
extra_kwargs
def
build_model
(
model_class
:
Type
[
nn
.
Module
],
hf_config
:
PretrainedConfig
,
def
build_model
(
model_class
:
Type
[
nn
.
Module
],
hf_config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
*
,
quant_config
:
Optional
[
QuantizationConfig
],
*
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
scheduler_config
:
Optional
[
SchedulerConfig
])
->
nn
.
Module
:
scheduler_config
:
Optional
[
SchedulerConfig
],
prefix
:
Optional
[
str
]
=
None
)
->
nn
.
Module
:
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
multimodal_config
,
scheduler_config
)
if
prefix
:
extra_kwargs
[
"prefix"
]
=
prefix
return
model_class
(
config
=
hf_config
,
cache_config
=
cache_config
,
...
...
vllm/model_executor/models/blip2.py
View file @
bc73e982
...
...
@@ -507,7 +507,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/gemma.py
View file @
bc73e982
...
...
@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -83,16 +84,23 @@ class GemmaMLP(nn.Module):
hidden_act
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
def
forward
(
self
,
x
):
...
...
@@ -104,7 +112,8 @@ class GemmaMLP(nn.Module):
class
GemmaAttention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
...
...
@@ -112,7 +121,9 @@ class GemmaAttention(nn.Module):
max_position_embeddings
:
int
=
8192
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -142,12 +153,14 @@ class GemmaAttention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -186,6 +199,7 @@ class GemmaDecoderLayer(nn.Module):
config
:
GemmaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -198,6 +212,7 @@ class GemmaDecoderLayer(nn.Module):
rope_theta
=
config
.
rope_theta
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
GemmaMLP
(
hidden_size
=
self
.
hidden_size
,
...
...
@@ -205,6 +220,7 @@ class GemmaDecoderLayer(nn.Module):
hidden_act
=
config
.
hidden_act
,
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -259,8 +275,8 @@ class GemmaModel(nn.Module):
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
GemmaDecoderLayer
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
GemmaDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -366,6 +382,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -375,7 +392,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
GemmaModel
(
config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
...
...
vllm/model_executor/models/internlm2.py
View file @
bc73e982
...
...
@@ -30,7 +30,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
InternLM2MLP
(
nn
.
Module
):
...
...
@@ -41,16 +42,23 @@ class InternLM2MLP(nn.Module):
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
w2
=
RowParallelLinear
(
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
w2
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.w2"
,
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -75,6 +83,7 @@ class InternLM2Attention(nn.Module):
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -108,12 +117,14 @@ class InternLM2Attention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wqkv"
,
)
self
.
wo
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wo"
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -123,12 +134,15 @@ class InternLM2Attention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
):
seq_len
=
qkv
.
shape
[
0
]
...
...
@@ -176,6 +190,7 @@ class InternLMDecoderLayer(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -192,12 +207,14 @@ class InternLMDecoderLayer(nn.Module):
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attention"
,
)
self
.
feed_forward
=
InternLM2MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
,
)
self
.
attention_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -251,8 +268,8 @@ class InternLM2Model(nn.Module):
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
InternLMDecoderLayer
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
InternLMDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
...
...
@@ -306,14 +323,19 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
InternLM2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
InternLM2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"output"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
...
...
vllm/model_executor/models/internlm2_ve.py
View file @
bc73e982
...
...
@@ -15,7 +15,7 @@ from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP
,
InternLM2Model
)
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
make_layers
from
.utils
import
make_layers
,
maybe_prefix
class
InternLM2VEDecoderLayer
(
nn
.
Module
):
...
...
@@ -25,6 +25,7 @@ class InternLM2VEDecoderLayer(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -41,18 +42,21 @@ class InternLM2VEDecoderLayer(nn.Module):
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attention"
,
)
self
.
feed_forward
=
InternLM2MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
,
)
self
.
feed_forward_ve
=
InternLM2MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward_ve"
,
)
self
.
attention_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -111,8 +115,8 @@ class InternLM2VEModel(InternLM2Model):
super
().
__init__
(
config
,
cache_config
,
quant_config
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
InternLM2VEDecoderLayer
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
InternLM2VEDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
def
forward
(
...
...
@@ -161,6 +165,10 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
,
cache_config
,
quant_config
)
self
.
model
=
InternLM2VEModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
InternLM2VEModel
(
config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
vllm/model_executor/models/internvl.py
View file @
bc73e982
...
...
@@ -439,7 +439,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
...
...
vllm/model_executor/models/llama.py
View file @
bc73e982
...
...
@@ -55,7 +55,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -500,6 +501,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -510,7 +512,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
@@ -526,6 +528,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
...
...
vllm/model_executor/models/llava.py
View file @
bc73e982
...
...
@@ -210,6 +210,7 @@ def init_vision_tower_for_llava(
quant_config
:
Optional
[
QuantizationConfig
],
*
,
require_post_norm
:
Optional
[
bool
]
=
None
,
prefix
:
str
=
""
,
):
vision_config
=
hf_config
.
vision_config
...
...
@@ -224,23 +225,26 @@ def init_vision_tower_for_llava(
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
CLIPVisionModel
(
vision_config
,
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers
,
require_post_norm
=
require_post_norm
,
prefix
=
prefix
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
SiglipVisionModel
(
vision_config
,
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers
,
require_post_norm
=
require_post_norm
,
prefix
=
prefix
,
)
elif
isinstance
(
vision_config
,
PixtralVisionConfig
):
return
PixtralHFVisionModel
(
vision_config
,
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers
,
require_post_norm
=
require_post_norm
,
prefix
=
prefix
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
@@ -274,14 +278,20 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
,
require_post_norm
=
False
)
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_next.py
View file @
bc73e982
...
...
@@ -293,7 +293,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
,
require_post_norm
=
False
)
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
...
...
@@ -302,7 +305,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
...
...
vllm/model_executor/models/llava_next_video.py
View file @
bc73e982
...
...
@@ -257,14 +257,20 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
,
require_post_norm
=
False
)
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
multi_modal_projector
=
LlavaNextMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_onevision.py
View file @
bc73e982
...
...
@@ -415,10 +415,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# Initialize the vision tower only up to the required feature layer
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
quant_config
,
require_post_norm
=
False
)
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
...
...
vllm/model_executor/models/minicpmv.py
View file @
bc73e982
...
...
@@ -394,8 +394,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
multimodal_config
=
multimodal_config
self
.
version
=
get_version_by_config
(
self
.
config
)
self
.
llm
=
self
.
init_llm
(
config
,
cache_config
,
quant_config
)
self
.
vpm
=
self
.
init_vision_module
(
config
,
quant_config
)
self
.
llm
=
self
.
init_llm
(
config
,
cache_config
,
quant_config
,
prefix
=
"llm"
)
self
.
vpm
=
self
.
init_vision_module
(
config
,
quant_config
,
prefix
=
"vpm"
)
param_dtype
=
torch
.
get_default_dtype
()
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
...
...
@@ -403,9 +406,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
"llm.lm_head"
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -644,6 +649,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
raise
NotImplementedError
...
...
@@ -651,6 +657,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
)
->
nn
.
Module
:
raise
NotImplementedError
...
...
@@ -690,17 +697,20 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
)
->
nn
.
Module
:
# TODO :refactor this vision model
try
:
...
...
@@ -819,19 +829,23 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
)
->
nn
.
Module
:
model
=
Idefics2VisionTransformer
(
config
.
vision_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
...
...
@@ -935,20 +949,24 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
)
->
nn
.
Module
:
model
=
Idefics2VisionTransformer
(
config
.
vision_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
if
self
.
config
.
drop_vision_last_layer
:
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
...
...
vllm/model_executor/models/opt.py
View file @
bc73e982
...
...
@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
...
...
@@ -68,6 +69,7 @@ class OPTAttention(nn.Module):
bias
:
bool
=
True
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
embed_dim
=
embed_dim
...
...
@@ -85,18 +87,21 @@ class OPTAttention(nn.Module):
total_num_heads
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
out_proj
=
RowParallelLinear
(
embed_dim
,
embed_dim
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scaling
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -118,6 +123,7 @@ class OPTDecoderLayer(nn.Module):
config
:
OPTConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -128,6 +134,7 @@ class OPTDecoderLayer(nn.Module):
bias
=
config
.
enable_bias
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
...
...
@@ -139,6 +146,7 @@ class OPTDecoderLayer(nn.Module):
config
.
ffn_dim
,
bias
=
config
.
enable_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
config
.
ffn_dim
)
...
...
@@ -147,6 +155,7 @@ class OPTDecoderLayer(nn.Module):
self
.
embed_dim
,
bias
=
config
.
enable_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
...
...
@@ -214,7 +223,8 @@ class OPTDecoder(nn.Module):
self
.
project_out
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
word_embed_proj_dim
,
bias
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.project_out"
)
else
:
self
.
project_out
=
None
...
...
@@ -222,7 +232,8 @@ class OPTDecoder(nn.Module):
self
.
project_in
=
ReplicatedLinear
(
config
.
word_embed_proj_dim
,
config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.project_in"
)
else
:
self
.
project_in
=
None
...
...
@@ -239,7 +250,8 @@ class OPTDecoder(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
OPTDecoderLayer
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
OPTDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -288,9 +300,13 @@ class OPTModel(nn.Module):
config
:
OPTConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
decoder
=
OPTDecoder
(
config
,
cache_config
,
quant_config
)
self
.
decoder
=
OPTDecoder
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.decoder"
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
...
...
@@ -335,11 +351,15 @@ class OPTForCausalLM(nn.Module, SupportsPP):
config
:
OPTConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
OPTModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
OPTModel
(
config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
decoder
.
embed_tokens
else
:
...
...
vllm/model_executor/models/paligemma.py
View file @
bc73e982
...
...
@@ -143,14 +143,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
multimodal_config
=
multimodal_config
self
.
vision_tower
=
SiglipVisionModel
(
config
.
vision_config
,
quant_config
)
quant_config
,
prefix
=
"vision_tower"
)
self
.
multi_modal_projector
=
PaliGemmaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
projection_dim
=
config
.
vision_config
.
projection_dim
)
self
.
quant_config
=
quant_config
self
.
language_model
=
GemmaForCausalLM
(
config
.
text_config
,
cache_config
,
quant_config
)
cache_config
,
quant_config
,
prefix
=
"language_model"
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
...
...
vllm/model_executor/models/phi3v.py
View file @
bc73e982
...
...
@@ -71,7 +71,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
def
_init_img_processor
(
hf_config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]):
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
)
->
CLIPVisionModel
:
clip_config
=
CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx
=
hf_config
.
img_processor
.
get
(
'layer_idx'
,
-
2
)
...
...
@@ -86,6 +87,7 @@ def _init_img_processor(hf_config: PretrainedConfig,
clip_config
,
quant_config
,
num_hidden_layers_override
=
num_hidden_layers
,
prefix
=
prefix
,
)
return
img_processor
...
...
@@ -152,15 +154,18 @@ class Phi3ImageEmbeddingBase(nn.Module):
class
Phi3HDImageEmbedding
(
Phi3ImageEmbeddingBase
):
"""Phi3 Image embedding with HD transform."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
])
->
None
:
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
# n_embed or hidden_size
hidden_size
=
config
.
n_embd
if
hasattr
(
config
,
'n_embd'
)
else
config
.
hidden_size
self
.
img_processor
=
_init_img_processor
(
config
,
quant_config
)
self
.
img_processor
=
_init_img_processor
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.img_processor"
)
image_dim_out
=
config
.
img_processor
[
'image_dim_out'
]
self
.
num_img_tokens
=
config
.
img_processor
[
'num_img_tokens'
]
...
...
@@ -537,11 +542,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
prefix
=
"model.embed_tokens"
,
)
# TODO: Optionally initializes this for supporting input embeddings.
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
config
,
quant_config
)
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
config
,
quant_config
,
prefix
=
"model.vision_embed_tokens"
)
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
self
.
language_model
=
LlamaForCausalLM
(
config
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/pixtral.py
View file @
bc73e982
...
...
@@ -164,7 +164,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
...
...
vllm/model_executor/models/qwen2.py
View file @
bc73e982
...
...
@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
Qwen2MLP
(
nn
.
Module
):
...
...
@@ -60,16 +61,23 @@ class Qwen2MLP(nn.Module):
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -92,7 +100,8 @@ class Qwen2Attention(nn.Module):
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
)
->
None
:
rope_scaling
:
Optional
[
Tuple
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -122,12 +131,14 @@ class Qwen2Attention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -142,7 +153,8 @@ class Qwen2Attention(nn.Module):
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
...
...
@@ -166,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
config
:
Qwen2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -180,12 +193,15 @@ class Qwen2DecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
rope_scaling
=
rope_scaling
)
rope_scaling
=
rope_scaling
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
Qwen2MLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -241,6 +257,7 @@ class Qwen2Model(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
...
...
@@ -249,7 +266,8 @@ class Qwen2Model(nn.Module):
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers"
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
...
...
@@ -393,6 +411,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
...
...
@@ -412,14 +431,19 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
bc73e982
...
...
@@ -938,7 +938,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
quant_config
=
None
,
)
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
"model"
)
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
...
...
@@ -946,7 +949,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
"lm_head"
)
else
:
self
.
lm_head
=
PPMissingLayer
()
...
...
vllm/model_executor/models/ultravox.py
View file @
bc73e982
...
...
@@ -357,7 +357,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
,
prefix
=
"language_model"
)
if
config
.
text_model_id
is
not
None
:
self
.
secondary_weights
.
append
(
DefaultModelLoader
.
Source
(
model_or_path
=
config
.
text_model_id
,
...
...
vllm/model_executor/models/utils.py
View file @
bc73e982
...
...
@@ -242,6 +242,7 @@ def init_vllm_registered_model(
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
"""
Helper function to initialize an inner model registered to vLLM,
...
...
@@ -257,6 +258,7 @@ def init_vllm_registered_model(
lora_config
=
lora_config
,
multimodal_config
=
multimodal_config
,
scheduler_config
=
scheduler_config
,
prefix
=
prefix
,
)
...
...
@@ -610,3 +612,16 @@ def get_vit_attn_backend() -> _Backend:
else
:
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
def
maybe_prefix
(
prefix
:
str
,
name
:
str
)
->
str
:
"""Add a prefix to a name if the prefix is non-empty.
Args:
prefix: The prefix to add. If empty, no prefix will be added.
name: The name to potentially prefix.
Returns:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return
name
if
not
prefix
else
f
"
{
prefix
}
.
{
name
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment