chenpangpang / transformers · Commit df04959e (unverified)
Authored by Kai on Sep 07, 2023; committed via GitHub on Sep 07, 2023
Fix _resize_token_embeddings setting the lm head size to 0 when DeepSpeed ZeRO-3 is enabled (#26024)
Parent: e3a97163
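Why the fix is needed: under DeepSpeed ZeRO-3 every parameter is partitioned across ranks, so the tensor visible on a given rank can be empty and new_embeddings.weight.shape[0] evaluates to 0. When pad_to_multiple_of was set, that 0 was then passed on to _get_resized_lm_head, shrinking the lm head to zero rows. Below is a minimal sketch of the read-through-gather pattern the patch uses; the helper name gathered_vocab_size is my own, not part of transformers, and the behavior described in the comments assumes an active ZeRO-3 run:

import deepspeed
import torch.nn as nn

def gathered_vocab_size(embedding: nn.Embedding) -> int:
    # Under ZeRO-3 the local shard of embedding.weight can have shape
    # torch.Size([0]), so reading .shape[0] directly returns 0.
    # GatheredParameters temporarily reassembles the full parameter;
    # modifier_rank=None requests read-only access (no rank writes back).
    with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=None):
        return embedding.weight.shape[0]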
Changes: 1 changed file with 11 additions and 1 deletion

src/transformers/modeling_utils.py (+11, -1)
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1437,10 +1437,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             add_hook_to_module(new_embeddings, hook)
         self.set_input_embeddings(new_embeddings)
 
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
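With the fix in place, resizing under ZeRO-3 while padding the vocabulary keeps the lm head and the input embeddings the same size. A minimal usage sketch of the call path the patch touches (the checkpoint and added tokens are illustrative assumptions; note that gpt2 ties its embeddings, so the untied-lm-head branch in the patch only runs for models with tie_word_embeddings=False):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Register new tokens, then resize the embedding matrix to match.
tokenizer.add_tokens(["<extra_0>", "<extra_1>"])

# pad_to_multiple_of rounds the vocabulary dimension up to a multiple of 64
# (useful for tensor-core alignment); the patched code reads the padded size
# through deepspeed.zero.GatheredParameters before resizing the lm head.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)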