Unverified Commit 9264fc91 authored by Sina's avatar Sina Committed by GitHub
Browse files

Inconsistency in PreTrainedModel.resize_token_embeddings When ZeRO3 Is Enabled (#25394)

* Inconsistency in PreTrainedModel.resize_token_embeddings

This PR addresses https://github.com/huggingface/transformers/issues/25241.

In the previous implementation, when ZeRO stage 3 was enabled, resize_token_embeddings would create independent PyTorch weights on each device. Here we ensure that new embeddings are created with DeepSpeed init, and are properly partitioned across devices.

* formatting with black

* adding the removed comments back in

---------
Co-authored-by: Sina Moeini <smoeini@amazon.com>
parent b4d55488
......@@ -1502,24 +1502,40 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f" {nn.Embedding}."
)
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
# numbers of tokens to copy
n = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
# Build new embeddings
new_embeddings = nn.Embedding(
new_num_tokens,
old_embedding_dim,
device=old_embeddings.weight.device,
dtype=old_embeddings.weight.dtype,
)
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
# Build new embeddings
new_embeddings = nn.Embedding(
new_num_tokens,
old_embedding_dim,
device=old_embeddings.weight.device,
dtype=old_embeddings.weight.dtype,
)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
return new_embeddings
......@@ -1575,11 +1591,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Build new lm head
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = old_lm_head.bias is not None
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
......@@ -1587,23 +1598,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
new_lm_head = nn.Linear(
*new_lm_head_shape,
bias=has_new_lm_head_bias,
device=old_lm_head.weight.device,
dtype=old_lm_head.weight.dtype,
)
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
if torch.distributed.get_rank() == 0:
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[
:num_tokens_to_copy, :
]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[
:, :num_tokens_to_copy
]
self._init_weights(new_lm_head)
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
else:
new_lm_head = nn.Linear(
*new_lm_head_shape,
bias=has_new_lm_head_bias,
device=old_lm_head.weight.device,
dtype=old_lm_head.weight.dtype,
)
self._init_weights(new_lm_head)
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment