Unverified commit c46edfb8, authored by Raushan Turganbay, committed by GitHub

Resize embeds with DeepSpeed (#32214)

* fix resize when deepspeed is enabled

* deepspeed uses new embeds

* we needed this
parent fad15fba
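
For context, this is the user-facing path the fix targets: resizing token embeddings after growing the tokenizer's vocabulary. A minimal sketch of the call site (the model name and added token are illustrative assumptions, not part of this commit):

    # Illustrative call path; model/tokenizer choices are not from the commit.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Adding tokens grows the vocab; resize_token_embeddings() must then
    # report the new vocab size correctly, even when DeepSpeed ZeRO-3 has
    # partitioned the embedding weights across ranks.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<tool>"]})
    model.resize_token_embeddings(len(tokenizer))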
@@ -1980,12 +1980,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
 
+        # Since we are basically reusing the same old embeddings with new weight values, gathering is required
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
+                vocab_size = model_embeds.weight.shape[0]
+        else:
+            vocab_size = model_embeds.weight.shape[0]
+
         # Update base model and current model config
         if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = model_embeds.weight.shape[0]
+            self.config.text_config.vocab_size = vocab_size
         else:
-            self.config.vocab_size = model_embeds.weight.shape[0]
-        self.vocab_size = model_embeds.weight.shape[0]
+            self.config.vocab_size = vocab_size
+        self.vocab_size = vocab_size
 
         # Tie weights again if needed
         self.tie_weights()
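
Why the gather: under ZeRO-3 each rank holds only a shard of every parameter, so reading `model_embeds.weight.shape[0]` outside a gather context returns 0 on most ranks rather than the vocab size. A minimal sketch of the pattern the new branch applies (the helper name is ours, not the library's; the diff additionally pairs it with the `hf_quantizer` check so quantized models keep the plain read):

    from transformers.integrations import is_deepspeed_zero3_enabled

    def safe_num_embeddings(weight):
        # Read the true leading dimension whether or not ZeRO-3 sharding is active.
        if is_deepspeed_zero3_enabled():
            import deepspeed

            # modifier_rank=None gathers the full parameter on every rank for
            # read-only access; nothing is written back on context exit.
            with deepspeed.zero.GatheredParameters(weight, modifier_rank=None):
                return weight.shape[0]
        return weight.shape[0]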
@@ -2139,7 +2149,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                old_embeddings.weight.data = new_embeddings.weight.data
+                old_embeddings.weight = new_embeddings.weight
                 old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
 
                 # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
...
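
The second hunk rebinds the whole `nn.Parameter` instead of copying `.data`, which, going by the "deepspeed uses new embeds" commit message, leaves the module holding the new DeepSpeed-managed parameter rather than writing values into the old one. A hedged illustration of the rebinding pattern, outside DeepSpeed:

    import torch.nn as nn

    old_embeddings = nn.Embedding(4, 8)
    new_embeddings = nn.Embedding(6, 8)

    # Rebinding replaces the Parameter object itself; the module now owns the
    # resized tensor, so any framework state attached to the new parameter
    # travels with it.
    old_embeddings.weight = new_embeddings.weight
    old_embeddings.num_embeddings = new_embeddings.weight.shape[0]

    assert old_embeddings.weight.shape == (6, 8)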