Unverified Commit 421929b5 authored by Patrick von Platen, committed by GitHub

finish (#13593)

parent b5bab710
src/transformers/models/distilbert/modeling_distilbert.py
@@ -73,6 +73,17 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def create_sinusoidal_embeddings(n_pos, dim, out):
+    if is_deepspeed_zero3_enabled():
+        import deepspeed
+
+        with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
+            if torch.distributed.get_rank() == 0:
+                _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+    else:
+        _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+
+
+def _create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
     out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
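For readers skimming the hunk: `_create_sinusoidal_embeddings` fills `out` with the standard fixed sinusoidal position encodings (the display truncates before the cos half, which is unchanged in the file). A minimal, self-contained sketch with illustrative sizes:

import numpy as np
import torch

# Illustrative sizes; the real call passes config.max_position_embeddings and config.dim.
n_pos, dim = 4, 6
out = torch.empty(n_pos, dim)

# Same recipe as the function body above: even columns get sin, odd columns get cos
# of pos / 10000^(2*(j//2)/dim).
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out.requires_grad = False
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))

print(out[0])  # position 0: zeros in the sin slots, ones in the cos slots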
@@ -86,16 +97,6 @@ class Embeddings(nn.Module):
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
         if config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                )
+            create_sinusoidal_embeddings(
+                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
+            )
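The lines removed here are the same ZeRO-3 pattern that the first hunk moves into `create_sinusoidal_embeddings`, so every call site collapses to one plain call. For context, a hedged sketch of that pattern (only meaningful inside a launched DeepSpeed ZeRO-3 job; `init_full_weight` and `fill_fn` are illustrative names, not part of the patch):

import deepspeed
import torch

def init_full_weight(param, fill_fn):
    # Under ZeRO stage 3 each rank holds only a shard of `param`, so writing
    # into the full weight requires gathering it first. With modifier_rank=0,
    # the writes made by rank 0 inside the context are re-partitioned across
    # all ranks when the context exits.
    with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            fill_fn(param)  # runs once, on the gathered full-sized tensor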
@@ -475,22 +476,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

         if self.config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=self.config.max_position_embeddings,
-                            dim=self.config.dim,
-                            out=self.embeddings.position_embeddings.weight,
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=self.config.max_position_embeddings,
-                    dim=self.config.dim,
-                    out=self.embeddings.position_embeddings.weight,
-                )
+            create_sinusoidal_embeddings(
+                n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
+            )
         else:
             with torch.no_grad():
@@ -502,6 +489,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
                     self.embeddings.position_embeddings.weight = nn.Parameter(
                         old_position_embeddings_weight[:num_position_embeds_diff]
                     )
+        # move position_embeddings to correct device
+        self.embeddings.position_embeddings.to(self.device)

     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
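The two added lines address device placement: `nn.Embedding` allocates on the CPU, so a table rebuilt inside `resize_position_embeddings` would otherwise sit on a different device than a CUDA-resident model, and the next forward pass would fail with a device mismatch. A minimal sketch of the failure mode the fix prevents (the sizes are illustrative):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
emb = nn.Embedding(512, 64).to(device)  # original table lives on the model's device
emb = nn.Embedding(1024, 64)            # rebuilt during a resize: allocated on CPU
emb.to(device)                          # the fix: move it back to the model's device

ids = torch.arange(8, device=device)
print(emb(ids).device)                  # matches `device` only because of the move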
src/transformers/models/pegasus/modeling_pegasus.py
@@ -668,6 +668,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """
@@ -886,6 +887,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """