Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c4e46433
Unverified
Commit
c4e46433
authored
Nov 18, 2024
by
Isotr0py
Committed by
GitHub
Nov 18, 2024
Browse files
[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
d1557e66
Changes
74
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
124 additions
and
48 deletions
+124
-48
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+7
-2
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+6
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+9
-5
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+6
-2
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+6
-2
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+6
-3
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+6
-2
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+6
-2
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+6
-2
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+6
-2
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+6
-2
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+6
-2
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+6
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+4
-3
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+6
-2
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+6
-2
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+6
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+6
-3
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+6
-2
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+8
-4
No files found.
vllm/model_executor/models/medusa.py
View file @
c4e46433
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -156,8 +156,10 @@ class Medusa(nn.Module):
...
@@ -156,8 +156,10 @@ class Medusa(nn.Module):
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
weights_map
=
{}
weights_map
=
{}
...
@@ -181,9 +183,12 @@ class Medusa(nn.Module):
...
@@ -181,9 +183,12 @@ class Medusa(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
if
self
.
token_map
is
not
None
:
if
self
.
token_map
is
not
None
:
self
.
token_map
.
to
(
device
=
self
.
lm_heads
[
0
].
weight
.
device
)
self
.
token_map
.
to
(
device
=
self
.
lm_heads
[
0
].
weight
.
device
)
assert
(
self
.
truncated_vocab_size
assert
(
self
.
truncated_vocab_size
==
self
.
orig_vocab_size
)
or
(
self
.
token_map
is
not
None
)
==
self
.
orig_vocab_size
)
or
(
self
.
token_map
is
not
None
)
return
loaded_params
vllm/model_executor/models/minicpm.py
View file @
c4e46433
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -539,7 +539,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -539,7 +539,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -556,6 +557,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -556,6 +557,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -606,3 +608,5 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -606,3 +608,5 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/minicpmv.py
View file @
c4e46433
...
@@ -24,7 +24,7 @@ import math
...
@@ -24,7 +24,7 @@ import math
import
re
import
re
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.types
import
torch.types
...
@@ -602,7 +602,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -602,7 +602,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -612,6 +613,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -612,6 +613,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
if
key_to_modify
in
name
:
...
@@ -630,10 +632,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -630,10 +632,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
if
is_pp_missing_parameter
(
name
=
name
.
replace
(
weight_name
,
param_name
)
name
.
replace
(
weight_name
,
param_
name
)
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
...
@@ -646,6 +648,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -646,6 +648,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
"""
...
...
vllm/model_executor/models/mixtral.py
View file @
c4e46433
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Mixtral model."""
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -404,7 +404,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -404,7 +404,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -421,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -421,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
num_experts
=
self
.
config
.
num_local_experts
)
num_experts
=
self
.
config
.
num_local_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -478,3 +480,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -478,3 +480,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/mixtral_quant.py
View file @
c4e46433
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Mixtral model."""
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -409,7 +409,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -409,7 +409,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -418,6 +419,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -418,6 +419,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -448,3 +450,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -448,3 +450,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/mllama.py
View file @
c4e46433
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""PyTorch Mllama model."""
"""PyTorch Mllama model."""
import
math
import
math
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
import
numpy
as
np
import
numpy
as
np
...
@@ -1427,7 +1427,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1427,7 +1427,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
outputs
return
outputs
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
@@ -1437,7 +1438,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1437,7 +1438,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
(
".gate_up_proj"
,
".up_proj"
,
1
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
=
set
()
updated_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
'patch_embedding.weight'
in
name
:
if
'patch_embedding.weight'
in
name
:
name
=
name
.
replace
(
'patch_embedding.weight'
,
name
=
name
.
replace
(
'patch_embedding.weight'
,
...
@@ -1457,6 +1458,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1457,6 +1458,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
updated_params
.
add
(
name
)
return
updated_params
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
...
...
vllm/model_executor/models/mlp_speculator.py
View file @
c4e46433
import
math
import
math
from
typing
import
Iterable
,
List
,
Tuple
from
typing
import
Iterable
,
List
,
Set
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -188,11 +188,15 @@ class MLPSpeculator(nn.Module):
...
@@ -188,11 +188,15 @@ class MLPSpeculator(nn.Module):
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
param
=
params_dict
.
get
(
name
.
replace
(
"speculator."
,
""
))
param
=
params_dict
.
get
(
name
.
replace
(
"speculator."
,
""
))
if
param
is
not
None
:
if
param
is
not
None
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/mpt.py
View file @
c4e46433
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -324,8 +324,10 @@ class MPTForCausalLM(nn.Module, SupportsPP):
...
@@ -324,8 +324,10 @@ class MPTForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
...
@@ -336,3 +338,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
...
@@ -336,3 +338,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/nemotron.py
View file @
c4e46433
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights."""
"""Inference-only Nemotron model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -474,7 +474,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -474,7 +474,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
@@ -482,6 +483,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -482,6 +483,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -522,3 +524,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -522,3 +524,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/olmo.py
View file @
c4e46433
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
"""Inference-only OLMo model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -356,7 +356,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -356,7 +356,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -366,6 +367,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -366,6 +367,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -402,3 +404,5 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -402,3 +404,5 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/olmoe.py
View file @
c4e46433
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -364,7 +364,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -364,7 +364,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -383,6 +384,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -383,6 +384,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
num_experts
=
self
.
config
.
num_experts
)
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -455,3 +457,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -455,3 +457,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/opt.py
View file @
c4e46433
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
"""Inference-only OPT model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -394,7 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
...
@@ -394,7 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -402,6 +403,7 @@ class OPTForCausalLM(nn.Module, SupportsPP):
...
@@ -402,6 +403,7 @@ class OPTForCausalLM(nn.Module, SupportsPP):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"lm_head.weight"
in
name
and
self
.
config
.
tie_word_embeddings
:
if
"lm_head.weight"
in
name
and
self
.
config
.
tie_word_embeddings
:
continue
continue
...
@@ -431,3 +433,5 @@ class OPTForCausalLM(nn.Module, SupportsPP):
...
@@ -431,3 +433,5 @@ class OPTForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/orion.py
View file @
c4e46433
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# Copyright (c) OrionStar Inc.
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -327,7 +327,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -327,7 +327,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -337,6 +338,7 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -337,6 +338,7 @@ class OrionForCausalLM(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -368,3 +370,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -368,3 +370,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/paligemma.py
View file @
c4e46433
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
import
torch
import
torch
...
@@ -295,6 +295,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -295,6 +295,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/persimmon.py
View file @
c4e46433
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
"""Inference-only persimmon model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -324,8 +324,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
...
@@ -324,8 +324,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -358,3 +360,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
...
@@ -358,3 +360,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/phi.py
View file @
c4e46433
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(
"qkv_proj"
,
"v_proj"
,
"v"
)
(
"qkv_proj"
,
"v_proj"
,
"v"
)
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
...
@@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/phi3_small.py
View file @
c4e46433
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -457,9 +457,11 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -457,9 +457,11 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -471,3 +473,5 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -471,3 +473,5 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/phi3v.py
View file @
c4e46433
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
itertools
import
itertools
import
re
import
re
from
functools
import
cached_property
,
lru_cache
from
functools
import
cached_property
,
lru_cache
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
numpy
as
np
...
@@ -744,7 +744,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -744,7 +744,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
PoolerOutput
]:
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
orig_to_new_prefix
=
{
"model.vision_embed_tokens.wte"
:
"embed_tokens"
,
"model.vision_embed_tokens.wte"
:
"embed_tokens"
,
...
@@ -759,5 +760,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -759,5 +760,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The HF config doesn't specify whether these are tied,
# The HF config doesn't specify whether these are tied,
# so we detect it this way
# so we detect it this way
if
"embed_tokens"
not
in
autoloaded_weights
:
if
"embed_tokens
.weight
"
not
in
autoloaded_weights
:
self
.
embed_tokens
=
self
.
language_model
.
model
.
embed_tokens
self
.
embed_tokens
=
self
.
language_model
.
model
.
embed_tokens
autoloaded_weights
.
add
(
"embed_tokens.weight"
)
return
autoloaded_weights
vllm/model_executor/models/phimoe.py
View file @
c4e46433
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only PhiMoE model."""
"""Inference-only PhiMoE model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -598,7 +598,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -598,7 +598,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -613,6 +614,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -613,6 +614,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
num_experts
=
self
.
config
.
num_local_experts
)
num_experts
=
self
.
config
.
num_local_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -666,3 +668,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -666,3 +668,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/pixtral.py
View file @
c4e46433
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
functools
import
cached_property
from
functools
import
cached_property
from
itertools
import
tee
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
import
numpy
import
torch
import
torch
...
@@ -1053,7 +1053,8 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -1053,7 +1053,8 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
@@ -1063,6 +1064,7 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -1063,6 +1064,7 @@ class PixtralHFVisionModel(nn.Module):
(
".gate_up_proj"
,
".up_proj"
,
1
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
layer_count
=
len
(
self
.
transformer
.
layers
)
layer_count
=
len
(
self
.
transformer
.
layers
)
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
...
@@ -1075,8 +1077,8 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -1075,8 +1077,8 @@ class PixtralHFVisionModel(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
...
@@ -1085,3 +1087,5 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -1085,3 +1087,5 @@ class PixtralHFVisionModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment