Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
44c5621d
"...static/style/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cbb63c5bec618354a25583c0861f45d4a01d9812"
Unverified
Commit
44c5621d
authored
May 06, 2021
by
Patrick von Platen
Committed by
GitHub
May 06, 2021
Browse files
fix tests (#11615)
parent
7eee950a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
1 deletion
+8
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+8
-1
No files found.
src/transformers/modeling_utils.py
View file @
44c5621d
...
@@ -1249,6 +1249,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1249,6 +1249,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
has_prefix_module
=
any
(
s
.
startswith
(
prefix
)
for
s
in
loaded_keys
)
has_prefix_module
=
any
(
s
.
startswith
(
prefix
)
for
s
in
loaded_keys
)
expects_prefix_module
=
any
(
s
.
startswith
(
prefix
)
for
s
in
expected_keys
)
expects_prefix_module
=
any
(
s
.
startswith
(
prefix
)
for
s
in
expected_keys
)
# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix
=
not
has_prefix_module
and
expects_prefix_module
remove_prefix
=
not
has_prefix_module
and
expects_prefix_module
add_prefix
=
has_prefix_module
and
not
expects_prefix_module
add_prefix
=
has_prefix_module
and
not
expects_prefix_module
...
@@ -1347,13 +1350,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1347,13 +1350,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def
retrieve_modules_from_names
(
self
,
names
,
add_prefix
=
False
,
remove_prefix
=
False
):
def
retrieve_modules_from_names
(
self
,
names
,
add_prefix
=
False
,
remove_prefix
=
False
):
module_keys
=
set
([
"."
.
join
(
key
.
split
(
"."
)[:
-
1
])
for
key
in
names
])
module_keys
=
set
([
"."
.
join
(
key
.
split
(
"."
)[:
-
1
])
for
key
in
names
])
# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys
=
module_keys
.
union
(
set
([
"."
.
join
(
key
.
split
(
"."
)[:
-
2
])
for
key
in
names
if
key
[
-
1
].
isdigit
()]))
retrieved_modules
=
[]
retrieved_modules
=
[]
# retrieve all modules that has at least one missing weight name
# retrieve all modules that has at least one missing weight name
for
name
,
module
in
self
.
named_modules
():
for
name
,
module
in
self
.
named_modules
():
if
remove_prefix
:
if
remove_prefix
:
name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
if
name
.
startswith
(
self
.
base_model_prefix
)
else
name
name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
if
name
.
startswith
(
self
.
base_model_prefix
)
else
name
elif
add_prefix
:
elif
add_prefix
:
name
=
"."
.
join
([
self
.
base_model_prefix
,
name
])
name
=
"."
.
join
([
self
.
base_model_prefix
,
name
])
if
len
(
name
)
>
0
else
self
.
base_model_prefix
if
name
in
module_keys
:
if
name
in
module_keys
:
retrieved_modules
.
append
(
module
)
retrieved_modules
.
append
(
module
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment