Unverified commit ffbf989f, authored by Sourab Mangrulkar and committed by GitHub

DeepSpeed ZeRO-3 handling when resizing embedding layers (#26259)

* fix failing deepspeed slow tests

* fixes
parent 39df4eca
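
Context for the fix: under ZeRO-3, deepspeed.zero.Init partitions a module's parameters at construction time, so a freshly built nn.Embedding reports a weight shape of 0 on ranks that do not own the partition. resize_token_embeddings reads that shape to update config.vocab_size and it is consumed across modeling files, which is what broke training. The commit therefore builds the resized layer outside zero.Init and gathers the partitioned weights only for the copy. A minimal sketch of the new pattern follows (the function name is illustrative, not transformers API; assumes DeepSpeed ZeRO-3 and torch.distributed are already initialized):

    import torch.nn as nn
    import deepspeed


    def resize_embedding_under_zero3(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
        # Build the new layer normally (NOT under deepspeed.zero.Init), so its
        # shape is the real (new_num_tokens, embedding_dim) on every rank.
        new = nn.Embedding(
            new_num_tokens,
            old.embedding_dim,
            device=old.weight.device,
            dtype=old.weight.dtype,
        )
        n = min(old.num_embeddings, new_num_tokens)

        # Temporarily gather the ZeRO-3-partitioned weights for the copy;
        # with modifier_rank=0, only rank 0's in-place edit is written back
        # to the partitions when the context exits.
        with deepspeed.zero.GatheredParameters([old.weight, new.weight], modifier_rank=0):
            new.weight.data[:n, :] = old.weight.data[:n, :]
        return new
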
@@ -1550,7 +1550,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
 
-        if old_num_tokens == new_num_tokens:
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
             return old_embeddings
 
         if not isinstance(old_embeddings, nn.Embedding):
@@ -1560,13 +1560,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 f" {nn.Embedding}."
             )
 
-        # numbers of tokens to copy
-        n = min(old_num_tokens, new_num_tokens)
-
-        if is_deepspeed_zero3_enabled():
-            import deepspeed
-
-            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
-                # Build new embeddings
-                new_embeddings = nn.Embedding(
-                    new_num_tokens,
-                    old_embedding_dim,
+        # Build new embeddings
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
@@ -1574,26 +1573,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
-                    device=old_embeddings.weight.device,
-                    dtype=old_embeddings.weight.dtype,
-                )
-                params = [old_embeddings.weight, new_embeddings.weight]
-                with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                    # initialize all new embeddings (in particular added tokens)
-                    self._init_weights(new_embeddings)
-                    # Copy token embeddings from the previous weights
-                    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
-        else:
-            # Build new embeddings
-            new_embeddings = nn.Embedding(
-                new_num_tokens,
-                old_embedding_dim,
-                device=old_embeddings.weight.device,
-                dtype=old_embeddings.weight.dtype,
-            )
-
-            # initialize all new embeddings (in particular added tokens)
-            self._init_weights(new_embeddings)
-
-            # Copy token embeddings from the previous weights
-            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+            device=old_embeddings.weight.device,
+            dtype=old_embeddings.weight.dtype,
+        )
+
+        # initialize all new embeddings (in particular added tokens)
+        self._init_weights(new_embeddings)
+
+        # Copy token embeddings from the previous weights
+
+        # numbers of tokens to copy
+        n = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+        else:
+            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
 
         return new_embeddings
@@ -1636,7 +1630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
         )
 
-        if old_num_tokens == new_num_tokens:
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
             return old_lm_head
 
         if not isinstance(old_lm_head, nn.Linear):
@@ -1650,39 +1644,40 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
 
-        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-
-        # XXX: put the long block of code in a wrapper
-        if is_deepspeed_zero3_enabled():
-            import deepspeed
-
-            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
-                new_lm_head = nn.Linear(
-                    *new_lm_head_shape,
-                    bias=has_new_lm_head_bias,
-                    device=old_lm_head.weight.device,
-                    dtype=old_lm_head.weight.dtype,
-                )
-                params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
-                with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                    self._init_weights(new_lm_head)
-                    # Copy old lm head weights to new lm head
-                    if not transposed:
-                        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
-                    else:
-                        new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
-
-                    # Copy bias weights to new lm head
-                    if has_new_lm_head_bias:
-                        new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
-        else:
-            new_lm_head = nn.Linear(
-                *new_lm_head_shape,
-                bias=has_new_lm_head_bias,
-                device=old_lm_head.weight.device,
-                dtype=old_lm_head.weight.dtype,
-            )
-            self._init_weights(new_lm_head)
-
-            # Copy old lm head weights to new lm head
-            if not transposed:
-                new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=old_lm_head.weight.device,
+            dtype=old_lm_head.weight.dtype,
+        )
+
+        # initialize new lm head (in particular added tokens)
+        self._init_weights(new_lm_head)
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                self._copy_lm_head_original_to_resized(
+                    new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+                )
+        else:
+            self._copy_lm_head_original_to_resized(
+                new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+            )
+
+        return new_lm_head
+
+    def _copy_lm_head_original_to_resized(
+        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+    ):
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
@@ -1693,8 +1688,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if has_new_lm_head_bias:
             new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
 
-        return new_lm_head
-
     def resize_position_embeddings(self, new_num_position_embeddings: int):
         raise NotImplementedError(
             f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
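
For downstream users nothing changes at the call site; with a ZeRO-3 config the resize now just works. A usage sketch under the assumption of a Trainer-style ZeRO-3 setup (the model and token names are placeholders):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # grows the vocab by one

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # With DeepSpeed ZeRO-3 enabled, this now builds the resized embeddings and
    # lm_head outside deepspeed.zero.Init, so config.vocab_size is updated from
    # a real tensor shape and the row copies run under zero.GatheredParameters.
    model.resize_token_embeddings(len(tokenizer))
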