Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fa6ecb9a
Unverified
Commit
fa6ecb9a
authored
Nov 29, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 29, 2024
Browse files
[Model] Clean up MiniCPMV (#10751)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
c83919c7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
149 additions
and
215 deletions
+149
-215
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+16
-3
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
...els/decoder_only/vision_language/vlm_utils/model_utils.py
+12
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-5
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+79
-74
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+2
-3
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+33
-103
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+2
-26
No files found.
tests/models/decoder_only/vision_language/test_models.py
View file @
fa6ecb9a
...
@@ -295,16 +295,29 @@ VLM_TEST_SETTINGS = {
...
@@ -295,16 +295,29 @@ VLM_TEST_SETTINGS = {
)
)
],
],
),
),
"minicpmv"
:
VLMTestInfo
(
"minicpmv
_25
"
:
VLMTestInfo
(
models
=
[
"openbmb/MiniCPM-Llama3-V-2_5"
],
models
=
[
"openbmb/MiniCPM-Llama3-V-2_5"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
test_type
=
VLMTestType
.
IMAGE
,
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
{
img_prompt
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
"
,
# noqa: E501
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
{
img_prompt
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"(<image>./</image>)
\n
"
,
img_idx_to_prompt
=
lambda
idx
:
"(<image>./</image>)
\n
"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
get_stop_token_ids
=
lambda
tok
:
[
tok
.
eos_id
,
tok
.
eot_id
],
get_stop_token_ids
=
lambda
tok
:
[
tok
.
eos_id
,
tok
.
eot_id
],
postprocess_inputs
=
model_utils
.
wrap_inputs_post_processor
,
postprocess_inputs
=
model_utils
.
wrap_inputs_post_processor
,
hf_output_post_proc
=
model_utils
.
minicmpv_trunc_hf_output
,
hf_output_post_proc
=
model_utils
.
minicpmv_trunc_hf_output
,
),
"minicpmv_26"
:
VLMTestInfo
(
models
=
[
"openbmb/MiniCPM-V-2_6"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
{
img_prompt
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"(<image>./</image>)
\n
"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
get_stop_token_ids
=
lambda
tok
:
tok
.
convert_tokens_to_ids
([
'<|im_end|>'
,
'<|endoftext|>'
]),
# noqa: E501
postprocess_inputs
=
model_utils
.
ignore_inputs_post_processor
(
"image_sizes"
),
hf_output_post_proc
=
model_utils
.
minicpmv_trunc_hf_output
,
),
),
# Tests for phi3v currently live in another file because of a bug in
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
# transformers. Once this issue is fixed, we can enable them here instead.
...
...
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
View file @
fa6ecb9a
...
@@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
...
@@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
####### Post-processors for HF outputs
####### Post-processors for HF outputs
def
minic
m
pv_trunc_hf_output
(
hf_output
:
RunnerOutput
,
def
minicp
m
v_trunc_hf_output
(
hf_output
:
RunnerOutput
,
model
:
str
)
->
RunnerOutput
:
model
:
str
)
->
RunnerOutput
:
output_ids
,
output_str
,
out_logprobs
=
hf_output
output_ids
,
output_str
,
out_logprobs
=
hf_output
if
output_str
.
endswith
(
"<|eot_id|>"
):
if
output_str
.
endswith
(
"<|eot_id|>"
):
...
@@ -197,6 +197,17 @@ def get_key_type_post_processor(
...
@@ -197,6 +197,17 @@ def get_key_type_post_processor(
return
process
return
process
def
ignore_inputs_post_processor
(
hf_inp_key
:
str
)
->
Callable
[[
BatchEncoding
,
str
],
BatchEncoding
]:
"""Gets a handle to a post processor which ignores a given key."""
def
process
(
hf_inputs
:
BatchEncoding
,
dtype
:
str
):
del
hf_inputs
[
hf_inp_key
]
return
hf_inputs
return
process
def
wrap_inputs_post_processor
(
hf_inputs
:
BatchEncoding
,
dtype
:
str
):
def
wrap_inputs_post_processor
(
hf_inputs
:
BatchEncoding
,
dtype
:
str
):
return
{
"model_inputs"
:
hf_inputs
}
return
{
"model_inputs"
:
hf_inputs
}
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
fa6ecb9a
...
@@ -242,7 +242,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -242,7 +242,7 @@ class FusedMoE(torch.nn.Module):
def
_load_model_weight_or_group_weight_scale
(
self
,
shard_dim
:
int
,
def
_load_model_weight_or_group_weight_scale
(
self
,
shard_dim
:
int
,
expert_data
:
torch
.
Tensor
,
expert_data
:
torch
.
Tensor
,
shard_id
:
str
,
shard_id
:
str
,
loaded_weight
:
torch
.
t
ensor
,
loaded_weight
:
torch
.
T
ensor
,
tp_rank
:
int
):
tp_rank
:
int
):
# Load grouped weight scales for group quantization
# Load grouped weight scales for group quantization
# or model weights
# or model weights
...
@@ -261,7 +261,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -261,7 +261,7 @@ class FusedMoE(torch.nn.Module):
def
_load_per_channel_weight_scale
(
self
,
expert_data
:
torch
.
Tensor
,
def
_load_per_channel_weight_scale
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
t
ensor
,
loaded_weight
:
torch
.
T
ensor
,
tp_rank
:
int
):
tp_rank
:
int
):
# for per channel weight quantization
# for per channel weight quantization
if
shard_id
==
"w2"
:
if
shard_id
==
"w2"
:
...
@@ -274,7 +274,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -274,7 +274,7 @@ class FusedMoE(torch.nn.Module):
tp_rank
=
tp_rank
)
tp_rank
=
tp_rank
)
def
_load_w13
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
def
_load_w13
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
t
ensor
,
tp_rank
:
int
):
shard_id
:
str
,
loaded_weight
:
torch
.
T
ensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
...
@@ -292,7 +292,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -292,7 +292,7 @@ class FusedMoE(torch.nn.Module):
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
def
_load_w2
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
def
_load_w2
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
t
ensor
,
tp_rank
:
int
):
shard_id
:
str
,
loaded_weight
:
torch
.
T
ensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# down_proj: "RowParallel" so tp sharding on input_dim
...
@@ -311,7 +311,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -311,7 +311,7 @@ class FusedMoE(torch.nn.Module):
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
def
_load_g_idx
(
self
,
shard_id
:
str
,
expert_data
:
torch
.
Tensor
,
def
_load_g_idx
(
self
,
shard_id
:
str
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
loaded_weight
:
torch
.
t
ensor
,
tp_rank
:
int
):
shard_dim
:
int
,
loaded_weight
:
torch
.
T
ensor
,
tp_rank
:
int
):
if
shard_id
==
"w2"
:
if
shard_id
==
"w2"
:
self
.
_load_w2
(
shard_id
=
shard_id
,
self
.
_load_w2
(
shard_id
=
shard_id
,
...
...
vllm/model_executor/models/minicpm.py
View file @
fa6ecb9a
...
@@ -52,7 +52,7 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -52,7 +52,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -378,6 +378,7 @@ class MiniCPMModel(nn.Module):
...
@@ -378,6 +378,7 @@ class MiniCPMModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
_init_layers
(
prefix
,
config
,
cache_config
,
quant_config
)
self
.
_init_layers
(
prefix
,
config
,
cache_config
,
quant_config
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
...
@@ -437,6 +438,73 @@ class MiniCPMModel(nn.Module):
...
@@ -437,6 +438,73 @@ class MiniCPMModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
[
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2s"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
for
expert_id
in
range
(
self
.
num_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
param_name
,
weight_name
,
expert_id
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -480,8 +548,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -480,8 +548,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
prefix
=
maybe_prefix
(
prefix
,
"model"
))
unpadded_vocab_size
=
config
.
vocab_size
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
@@ -506,8 +575,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -506,8 +575,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
.
model
=
MiniCPMModel
(
vllm_config
=
vllm_config
,
return
MiniCPMModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
prefix
=
maybe_prefix
(
prefix
,
"model"
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
@@ -546,72 +614,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -546,72 +614,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
# (param_name, shard_name, shard_id)
self
,
(
"qkv_proj"
,
"q_proj"
,
"q"
),
skip_prefixes
=
([
"lm_head."
]
(
"qkv_proj"
,
"k_proj"
,
"k"
),
if
self
.
config
.
tie_word_embeddings
else
None
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
return
loader
.
load_weights
(
weights
)
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
[
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2s"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
for
expert_id
in
range
(
self
.
num_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
param_name
,
weight_name
,
expert_id
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/minicpm3.py
View file @
fa6ecb9a
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM
,
MiniCPMForCausalLM
,
MiniCPMModel
)
MiniCPMModel
)
from
.utils
import
make_layers
,
maybe_prefix
from
.utils
import
make_layers
class
MiniCPM3Attention
(
nn
.
Module
):
class
MiniCPM3Attention
(
nn
.
Module
):
...
@@ -248,5 +248,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
...
@@ -248,5 +248,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
}
}
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
.
model
=
MiniCPM3Model
(
vllm_config
=
vllm_config
,
return
MiniCPM3Model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
prefix
=
maybe_prefix
(
prefix
,
"model"
))
vllm/model_executor/models/minicpmv.py
View file @
fa6ecb9a
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
math
import
re
import
re
from
functools
import
partial
from
functools
import
cached_property
,
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
Set
,
Tuple
,
TypedDict
,
Union
)
...
@@ -37,19 +37,15 @@ from vllm.attention import AttentionMetadata
...
@@ -37,19 +37,15 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
get_2d_sincos_pos_embed
)
get_2d_sincos_pos_embed
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.minicpm
import
MiniCPMForCausalLM
from
vllm.model_executor.models.minicpm
import
MiniCPMModel
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.models.qwen2
import
Qwen2ForCausalLM
from
vllm.model_executor.models.utils
import
LLMWrapper
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.image
import
cached_get_image_processor
...
@@ -58,11 +54,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
...
@@ -58,11 +54,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
is_pp_missing_parameter
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
}
RawImageType
=
Union
[
Image
.
Image
,
torch
.
Tensor
]
RawImageType
=
Union
[
Image
.
Image
,
torch
.
Tensor
]
...
@@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
...
@@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
if
version
==
(
2
,
0
)
or
version
==
(
2
,
5
):
if
version
==
(
2
,
0
)
or
version
==
(
2
,
5
):
return
image_processor
.
\
return
image_processor
.
get_slice_image_placeholder
(
image_size
)
get_slice_image_placeholder
(
image_size
)
return
image_processor
.
get_slice_image_placeholder
(
return
image_processor
.
\
image_size
,
num_image
)
get_slice_image_placeholder
(
image_size
,
num_image
)
prompt
=
inputs
.
get
(
"prompt"
)
prompt
=
inputs
.
get
(
"prompt"
)
token_ids
=
inputs
.
get
(
"prompt_token_ids"
)
token_ids
=
inputs
.
get
(
"prompt_token_ids"
)
...
@@ -400,37 +391,32 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -400,37 +391,32 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
vpm
=
self
.
init_vision_module
(
config
,
self
.
vpm
=
self
.
init_vision_module
(
config
,
quant_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vpm"
))
prefix
=
maybe_prefix
(
prefix
,
"vpm"
))
param_dtype
=
torch
.
get_default_dtype
()
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
self
.
vpm
.
embeddings
.
embed_dim
)
self
.
vpm
.
embeddings
.
embed_dim
)
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
,
self
.
vision_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
=
maybe_prefix
(
prefix
,
"resampler"
))
prefix
,
"resampler"
))
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"llm.lm_head"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
llm
.
make_empty_intermediate_tensors
)
self
.
llm
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
llm
,
"sampler"
):
return
self
.
llm
.
sampler
return
get_sampler
()
def
get_embedding
(
def
get_embedding
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
get_input_embeddings
(
input_ids
)
if
hasattr
(
self
.
config
,
"scale_emb"
):
vlm_embedding
*=
self
.
config
.
scale_emb
if
image_inputs
is
None
:
# No image
if
image_inputs
is
None
:
# No image
vision_hidden_states
=
torch
.
tensor
([],
device
=
input_ids
.
device
)
vision_hidden_states
=
torch
.
tensor
([],
device
=
input_ids
.
device
)
...
@@ -575,7 +561,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -575,7 +561,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# for `torch.compile` integration
# for `torch.compile` integration
input_ids
=
None
input_ids
=
None
output
=
self
.
llm
(
output
=
self
.
llm
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
...
@@ -590,9 +576,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -590,9 +576,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
return
self
.
llm
.
compute_logits
(
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
...
@@ -604,52 +588,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -604,52 +588,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
self
)
# (param_name, shard_name, shard_id)
return
loader
.
load_weights
(
weights
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading
=
False
if
self
.
is_default_weight_loading
(
name
):
use_default_weight_loading
=
True
else
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
"""
...
@@ -693,9 +633,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -693,9 +633,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
raise
NotImplementedError
class
MiniCPMV2_0
(
MiniCPMVBaseModel
):
class
MiniCPMV2_0
(
MiniCPMVBaseModel
):
...
@@ -708,8 +645,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -708,8 +645,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
return
LLMWrapper
(
MiniCPMModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
return
MiniCPMForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
name
=
"model"
)
def
init_vision_module
(
def
init_vision_module
(
self
,
self
,
...
@@ -717,11 +653,12 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -717,11 +653,12 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config
:
Optional
[
QuantizationConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
# TODO
:refactor this vision model
# TODO:
refactor this vision model
try
:
try
:
import
timm
import
timm
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Please install timm==0.9.10"
)
from
ImportError
raise
ImportError
(
"Please install timm==0.9.10"
)
from
ImportError
with
set_default_torch_dtype
(
torch
.
float16
):
with
set_default_torch_dtype
(
torch
.
float16
):
model
=
timm
.
create_model
(
model
=
timm
.
create_model
(
"vit_so400m_patch14_siglip_384.webli"
,
"vit_so400m_patch14_siglip_384.webli"
,
...
@@ -731,6 +668,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -731,6 +668,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
dynamic_img_pad
=
True
,
dynamic_img_pad
=
True
,
)
)
model
=
model
.
to
(
dtype
=
torch
.
get_default_dtype
())
if
(
isinstance
(
model
,
timm
.
models
.
VisionTransformer
)
if
(
isinstance
(
model
,
timm
.
models
.
VisionTransformer
)
and
model
.
attn_pool
is
not
None
):
and
model
.
attn_pool
is
not
None
):
model
.
attn_pool
=
torch
.
nn
.
Identity
()
model
.
attn_pool
=
torch
.
nn
.
Identity
()
...
@@ -759,7 +698,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -759,7 +698,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
)
return
resampler
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
def
get_vision_embedding
(
self
,
self
,
...
@@ -790,9 +729,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -790,9 +729,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return
self
.
get_vision_embedding
(
pixel_values
)
return
self
.
get_vision_embedding
(
pixel_values
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
or
"vpm"
in
name
class
MiniCPMV2_5
(
MiniCPMVBaseModel
,
SupportsLoRA
):
class
MiniCPMV2_5
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -843,8 +779,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -843,8 +779,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
return
LLMWrapper
(
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
return
LlamaForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
name
=
"model"
)
def
init_vision_module
(
def
init_vision_module
(
self
,
self
,
...
@@ -871,7 +806,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -871,7 +806,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
)
return
resampler
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
def
get_vision_embedding
(
self
,
self
,
...
@@ -913,9 +849,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -913,9 +849,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return
self
.
get_vision_embedding
(
all_pixel_values
.
type
(
dtype
),
return
self
.
get_vision_embedding
(
all_pixel_values
.
type
(
dtype
),
patch_attn_mask
,
tgt_sizes
)
patch_attn_mask
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
class
MiniCPMV2_6
(
MiniCPMVBaseModel
,
SupportsLoRA
):
class
MiniCPMV2_6
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -966,8 +899,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -966,8 +899,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
return
LLMWrapper
(
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
return
Qwen2ForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
name
=
"model"
)
def
init_vision_module
(
def
init_vision_module
(
self
,
self
,
...
@@ -995,7 +927,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -995,7 +927,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
)
return
resampler
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
def
get_vision_embedding
(
self
,
self
,
...
@@ -1043,9 +976,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -1043,9 +976,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
return
"resampler"
in
name
_SUPPORT_VERSION
=
{
_SUPPORT_VERSION
=
{
(
2
,
0
):
MiniCPMV2_0
,
(
2
,
0
):
MiniCPMV2_0
,
...
...
vllm/model_executor/models/utils.py
View file @
fa6ecb9a
import
itertools
import
itertools
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
from
typing
import
(
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Optional
,
Protocol
,
Set
,
Tuple
,
Union
,
overload
)
Protocol
,
Set
,
Tuple
,
Union
,
overload
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -560,30 +560,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
...
@@ -560,30 +560,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
return
make_empty_intermediate_tensors
return
make_empty_intermediate_tensors
class
LLMWrapper
(
nn
.
Module
):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
def
__init__
(
self
,
llm
:
nn
.
Module
,
name
:
str
)
->
None
:
super
().
__init__
()
self
.
model_name
=
name
setattr
(
self
,
name
,
llm
)
def
__getattr__
(
self
,
key
:
str
):
llm
=
super
().
__getattr__
(
self
.
model_name
)
if
key
==
self
.
model_name
:
return
llm
return
getattr
(
llm
,
key
)
# We need to explicitly override this
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
llm
=
super
().
__getattr__
(
self
.
model_name
)
return
llm
(
*
args
,
**
kwargs
)
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
def
get_vit_attn_backend
(
support_fa
:
bool
=
False
)
->
_Backend
:
"""
"""
Get the available attention backend for Vision Transformer.
Get the available attention backend for Vision Transformer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment