Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
df04959e
Unverified
Commit
df04959e
authored
Sep 07, 2023
by
Kai
Committed by
GitHub
Sep 07, 2023
Browse files
fix _resize_token_embeddings will set lm head size to 0 when enabled deepspeed zero3 (#26024)
parent
e3a97163
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
1 deletion
+11
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+11
-1
No files found.
src/transformers/modeling_utils.py
View file @
df04959e
...
...
@@ -1437,10 +1437,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_hook_to_module
(
new_embeddings
,
hook
)
self
.
set_input_embeddings
(
new_embeddings
)
# Update new_num_tokens with the actual size of new_embeddings
if
pad_to_multiple_of
is
not
None
:
if
is_deepspeed_zero3_enabled
():
import
deepspeed
with
deepspeed
.
zero
.
GatheredParameters
(
new_embeddings
.
weight
,
modifier_rank
=
None
):
new_num_tokens
=
new_embeddings
.
weight
.
shape
[
0
]
else
:
new_num_tokens
=
new_embeddings
.
weight
.
shape
[
0
]
# if word embeddings are not tied, make sure that lm head is resized as well
if
self
.
get_output_embeddings
()
is
not
None
and
not
self
.
config
.
tie_word_embeddings
:
old_lm_head
=
self
.
get_output_embeddings
()
new_lm_head
=
self
.
_get_resized_lm_head
(
old_lm_head
,
new_
embeddings
.
weight
.
shape
[
0
]
)
new_lm_head
=
self
.
_get_resized_lm_head
(
old_lm_head
,
new_
num_tokens
)
if
hasattr
(
old_lm_head
,
"_hf_hook"
):
hook
=
old_lm_head
.
_hf_hook
add_hook_to_module
(
new_lm_head
,
hook
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment