Unverified Commit 6b82d936, authored by Marc Sun and committed by GitHub

reattach hooks when using `resize_token_embeddings` (#25596)

* reattach hooks

* fix style
parent 6c811a32
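
For context: when a model is loaded with `device_map="auto"`, accelerate dispatches its modules and attaches an `AlignDevicesHook` to each one, stored on the module as `_hf_hook`. `resize_token_embeddings` builds a brand-new embedding module, so before this patch the hook was silently dropped and inputs were no longer moved to the module's execution device. A minimal sketch of the scenario, assuming accelerate is installed and the checkpoint is actually dispatched so hooks get attached; the model name and added token are illustrative:

```python
# Sketch of the failure mode this commit addresses. Assumes accelerate is
# installed and device_map="auto" actually dispatches the model (on a
# single-device setup hooks may not be attached); "gpt2" is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# accelerate stores its dispatch hook on each module as `_hf_hook`
print(hasattr(model.get_input_embeddings(), "_hf_hook"))

tokenizer.add_tokens(["<extra_token>"])
model.resize_token_embeddings(len(tokenizer))

# Before this patch the resized embedding was a fresh module with no
# `_hf_hook`; with the patch the old hook is re-attached to the new module.
print(hasattr(model.get_input_embeddings(), "_hf_hook"))
```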
@@ -90,6 +90,7 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 if is_accelerate_available():
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+    from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
         find_tied_parameters,
@@ -1442,12 +1443,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         old_embeddings = self.get_input_embeddings()
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+        if hasattr(old_embeddings, "_hf_hook"):
+            hook = old_embeddings._hf_hook
+            add_hook_to_module(new_embeddings, hook)
         self.set_input_embeddings(new_embeddings)

         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
             new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
+            if hasattr(old_lm_head, "_hf_hook"):
+                hook = old_lm_head._hf_hook
+                add_hook_to_module(new_lm_head, hook)
             self.set_output_embeddings(new_lm_head)

         return self.get_input_embeddings()
...
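
The core of the patch in isolation, as a standalone sketch: it assumes accelerate is installed, and the embedding sizes and the CPU execution device are arbitrary placeholders.

```python
# Standalone sketch of the hook-reattachment technique used in the patch,
# assuming accelerate is installed; sizes and device are illustrative.
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# Give the original module a hook, as dispatch_model would.
old_embeddings = nn.Embedding(100, 16)
add_hook_to_module(old_embeddings, AlignDevicesHook(execution_device="cpu"))

# Create the replacement module and carry the existing hook over,
# mirroring what _resize_token_embeddings now does.
new_embeddings = nn.Embedding(120, 16)
if hasattr(old_embeddings, "_hf_hook"):
    add_hook_to_module(new_embeddings, old_embeddings._hf_hook)

print(hasattr(new_embeddings, "_hf_hook"))  # True: the hook survived the swap
```

Reusing the existing hook object, rather than constructing a new one, preserves whatever execution device and offload settings dispatch originally chose for the module.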