Unverified Commit 614fef16 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Ensure OpenAI GPT position_ids is correctly initialized and registered at init. (#5773)



* Ensure OpenAI GPT position_ids is correctly initialized and registered as buffer at init.

This will make it compatible with TorchScript export.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Fix missing slice operator on the tensor data accessor.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Style.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Fixed BertEmbedding position_ids buffer created at forward.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fixed MobileBertEmbedding position_ids buffer created at forward.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>

* Fixed XLM position_ids buffer created at forward.
Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>
parent 3b44aa93
...@@ -180,6 +180,9 @@ class BertEmbeddings(nn.Module): ...@@ -180,6 +180,9 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
...@@ -187,12 +190,12 @@ class BertEmbeddings(nn.Module): ...@@ -187,12 +190,12 @@ class BertEmbeddings(nn.Module):
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1] seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None: if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) position_ids = self.position_ids[:, :seq_length]
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
......
...@@ -179,18 +179,22 @@ class MobileBertEmbeddings(nn.Module): ...@@ -179,18 +179,22 @@ class MobileBertEmbeddings(nn.Module):
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size) self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1] seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None: if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) position_ids = self.position_ids[:, :seq_length]
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids) inputs_embeds = self.word_embeddings(input_ids)
......
...@@ -391,6 +391,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -391,6 +391,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.drop = nn.Dropout(config.embd_pdrop) self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
self.register_buffer("position_ids", torch.arange(config.n_positions))
self.init_weights() self.init_weights()
def get_input_embeddings(self): def get_input_embeddings(self):
...@@ -443,9 +444,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -443,9 +444,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if position_ids is None: if position_ids is None:
# Code is different from when we had a single embedding matrice from position and token embeddings # Code is different from when we had a single embedding matrice from position and token embeddings
device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.position_ids[None, : input_shape[-1]]
position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask. # Attention mask.
if attention_mask is not None: if attention_mask is not None:
......
...@@ -442,6 +442,7 @@ class XLMModel(XLMPreTrainedModel): ...@@ -442,6 +442,7 @@ class XLMModel(XLMPreTrainedModel):
self.prune_heads({int(layer): list(map(int, heads))}) self.prune_heads({int(layer): list(map(int, heads))})
self.init_weights() self.init_weights()
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
...@@ -511,12 +512,9 @@ class XLMModel(XLMPreTrainedModel): ...@@ -511,12 +512,9 @@ class XLMModel(XLMPreTrainedModel):
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids # position_ids
if position_ids is None: if position_ids is None:
position_ids = torch.arange(slen, dtype=torch.long, device=device) position_ids = self.position_ids[:, :slen]
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
else: else:
assert position_ids.size() == (bs, slen) # (slen, bs) assert position_ids.size() == (bs, slen) # (slen, bs)
# position_ids = position_ids.transpose(0, 1) # position_ids = position_ids.transpose(0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment