"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c6d664849bdc580cf813b2d3a555a9b33d31b33d"
Unverified commit df04959e, authored by Kai and committed by GitHub

Fix _resize_token_embeddings setting the lm head size to 0 when DeepSpeed ZeRO-3 is enabled (#26024)

parent e3a97163
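
For context (this note is not part of the original commit, and the checkpoint name below is purely hypothetical): the bug surfaces when resize_token_embeddings is called with pad_to_multiple_of on a model whose input and output embeddings are untied (config.tie_word_embeddings == False) while DeepSpeed ZeRO-3 is active. A minimal sketch of such a call site:

# Hypothetical trigger, not from the PR; "some-org/causal-lm" stands in for any
# checkpoint with untied embeddings that is being prepared for ZeRO-3 training.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/causal-lm")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model = AutoModelForCausalLM.from_pretrained("some-org/causal-lm")

# Before this fix, _resize_token_embeddings read new_embeddings.weight.shape[0],
# which is 0 for a ZeRO-3 partitioned parameter, so the lm head came back with 0 rows.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)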
@@ -1437,10 +1437,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             add_hook_to_module(new_embeddings, hook)
         self.set_input_embeddings(new_embeddings)
 
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
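
The mechanism behind the fix, as a hedged standalone sketch (not part of the commit): under ZeRO-3 a parameter created inside deepspeed.zero.Init is immediately partitioned, so its local tensor is empty and shape[0] reads as 0; only inside deepspeed.zero.GatheredParameters is the full weight materialized, which is why the new code reads new_num_tokens there before resizing the lm head.

# Standalone sketch, assuming deepspeed is installed and the script is started with a
# distributed launcher (e.g. the `deepspeed` CLI); the vocab/hidden sizes are arbitrary.
import torch.nn as nn
import deepspeed

with deepspeed.zero.Init():  # parameters are partitioned as soon as they are created
    new_embeddings = nn.Embedding(50264, 768)

# The local shard is empty, so this prints 0, the value the old code passed
# to _get_resized_lm_head, producing a 0-row lm head.
print(new_embeddings.weight.shape[0])

# Gathering materializes the full weight on every rank, so shape[0] is the real
# (padded) vocabulary size; the fix reads new_num_tokens inside this context.
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
    new_num_tokens = new_embeddings.weight.shape[0]
print(new_num_tokens)  # 50264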