chenpangpang / transformers · Commits

Commit a5cc30d7 (unverified), authored Jul 25, 2023 by Marc Sun, committed by GitHub on Jul 25, 2023
Parent: f1deb21f

fix tied_params for meta tensor (#25101)

* fix tied_params for meta tensor
* remove duplicate
Changes: 1 changed file with 10 additions and 6 deletions.

src/transformers/modeling_utils.py (+10 −6)
@@ -3060,13 +3060,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 unexpected_keys = list(unexpected_keys - model_buffers)

         model.tie_weights()
-        ptrs = collections.defaultdict(list)
-        for name, tensor in model.state_dict().items():
-            id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
-            ptrs[id_tensor].append(name)
-        # These are all the pointers of shared tensors.
-        tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+        if device_map is None:
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model.state_dict().items():
+                id_tensor = id_tensor_storage(tensor)
+                ptrs[id_tensor].append(name)
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+        else:
+            # id function doesn't work for meta tensor so we need this function
+            tied_params = find_tied_parameters(model)
         for group in tied_params:
             if remove_prefix_from_model:
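The two detection strategies in the hunk above can be sketched without torch. Everything below is a hypothetical, stdlib-only illustration: `FakeTensor`, its `storage_id` field, and the simplified `id_tensor_storage` / identity-grouping helpers are stand-ins for the real transformers and accelerate APIs, not their actual implementations.

```python
import collections

class FakeTensor:
    # Hypothetical stand-in for torch.Tensor; the real code derives a key
    # from the tensor's storage data pointer via id_tensor_storage().
    def __init__(self, storage_id, device="cpu"):
        self.storage_id = storage_id
        self.device = device

def id_tensor_storage(tensor):
    # Storage-based key: only meaningful when the tensor has real storage.
    return (tensor.device, tensor.storage_id)

def tied_params_by_storage(state_dict):
    # The `device_map is None` branch: group parameter names by storage
    # identity; names sharing one storage are tied.
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        ptrs[id_tensor_storage(tensor)].append(name)
    return [names for names in ptrs.values() if len(names) > 1]

def tied_params_by_identity(state_dict):
    # The `device_map` branch: meta tensors have no real storage to compare,
    # so group by object identity instead (the idea behind accelerate's
    # find_tied_parameters, which walks modules rather than a state dict).
    groups = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        groups[id(tensor)].append(name)
    return [names for names in groups.values() if len(names) > 1]

shared = FakeTensor(storage_id=1)
state_dict = {
    "embed.weight": shared,          # tied to lm_head.weight
    "lm_head.weight": shared,
    "layer.weight": FakeTensor(storage_id=2),
}
print(tied_params_by_storage(state_dict))   # [['embed.weight', 'lm_head.weight']]
print(tied_params_by_identity(state_dict))  # [['embed.weight', 'lm_head.weight']]
```

Both strategies agree here because the tied entries are literally the same object; the commit's point is that only the identity-based route stays reliable once parameters live on the meta device.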