Unverified Commit 421929b5 authored by Patrick von Platen, committed by GitHub

finish (#13593)

parent b5bab710
@@ -73,6 +73,17 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def create_sinusoidal_embeddings(n_pos, dim, out):
+    if is_deepspeed_zero3_enabled():
+        import deepspeed
+
+        with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
+            if torch.distributed.get_rank() == 0:
+                _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+    else:
+        _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+
+
+def _create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
     out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
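The new wrapper centralizes the DeepSpeed ZeRO-3 handling that the call sites below previously duplicated: under ZeRO stage 3 the embedding weight is partitioned across ranks, so it must be gathered before it can be written in place, only rank 0 performs the write, and DeepSpeed re-partitions the updated parameter when the context manager exits. A minimal sketch of the same gather-then-initialize pattern in isolation (the `zero3_safe_fill_` helper is illustrative and not part of this commit; the import path of `is_deepspeed_zero3_enabled` has moved between transformers releases):

import torch
from transformers.deepspeed import is_deepspeed_zero3_enabled  # `transformers.integrations.deepspeed` in newer releases


def zero3_safe_fill_(param: torch.nn.Parameter, value: float) -> None:
    # Write `value` into `param` in a way that also works when the parameter
    # is partitioned by DeepSpeed ZeRO stage 3.
    if is_deepspeed_zero3_enabled():
        import deepspeed

        # Gather the full parameter on participating ranks; only rank 0 modifies
        # it, and DeepSpeed re-partitions the updated tensor on context exit.
        with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                with torch.no_grad():
                    param.fill_(value)
    else:
        with torch.no_grad():
            param.fill_(value)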
@@ -86,19 +97,9 @@ class Embeddings(nn.Module):
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
         if config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                )
+            create_sinusoidal_embeddings(
+                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
+            )

         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
@@ -475,23 +476,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

         if self.config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=self.config.max_position_embeddings,
-                            dim=self.config.dim,
-                            out=self.embeddings.position_embeddings.weight,
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=self.config.max_position_embeddings,
-                    dim=self.config.dim,
-                    out=self.embeddings.position_embeddings.weight,
-                )
+            create_sinusoidal_embeddings(
+                n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
+            )
         else:
             with torch.no_grad():
                 if num_position_embeds_diff > 0:
@@ -502,6 +489,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
                     self.embeddings.position_embeddings.weight = nn.Parameter(
                         old_position_embeddings_weight[:num_position_embeds_diff]
                     )
+        # move position_embeddings to correct device
+        self.embeddings.position_embeddings.to(self.device)

     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
...
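On the DistilBERT side, the net effect is that `resize_position_embeddings` now always leaves the rebuilt embedding table on the model's device. A rough usage sketch, assuming the `resize_position_embeddings` / `get_position_embeddings` API that these hunks modify (default, non-sinusoidal configuration):

from transformers import DistilBertConfig, DistilBertModel

config = DistilBertConfig()  # learned position embeddings; sinusoidal_pos_embds is False by default
model = DistilBertModel(config)

# Growing the table copies the old weights into a larger nn.Embedding and,
# with this commit, moves the new table onto the model's device afterwards.
model.resize_position_embeddings(1024)
print(model.get_position_embeddings().weight.shape)  # expected: torch.Size([1024, 768])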
@@ -668,6 +668,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """
@@ -886,6 +887,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """
...
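The Pegasus change is the same device fix for the fixed sinusoidal tables: `resize_position_embeddings` rebuilds `embed_positions` in encoder and decoder, and the added `.to(self.device)` keeps the new tables next to the rest of the model's weights. A rough sketch, again under the assumption that the full model exposes the resize API touched above:

from transformers import PegasusConfig, PegasusModel

config = PegasusConfig(max_position_embeddings=512)
model = PegasusModel(config)

# Rebuilds the sinusoidal position tables of encoder and decoder; with this
# commit they end up on the model's device rather than the default CPU.
model.resize_position_embeddings(1024)
print(model.encoder.get_position_embeddings().weight.shape)  # expected: torch.Size([1024, 1024])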