"docs/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "eb82e8b56658d7a3fc352f551d4a447649e625ca"
Unverified Commit 614fef16 authored by Funtowicz Morgan, committed by GitHub

Ensure OpenAI GPT position_ids is correctly initialized and registered at init. (#5773)



* Ensure OpenAI GPT position_ids is correctly initialized and registered as buffer at init.

This will make it compatible with TorchScript export.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix missing slice operator on the tensor data accessor.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Style.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fixed BertEmbedding position_ids buffer created at forward.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fixed MobileBertEmbedding position_ids buffer created at forward.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fixed XLM position_ids buffer created at forward.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
parent 3b44aa93
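
To make the TorchScript point from the commit message above concrete, here is a minimal, self-contained sketch of the pattern these diffs apply. The ToyEmbeddings module, its sizes, and the variable names are illustrative assumptions, not transformers code. Registering position_ids once as a buffer gives torch.jit.script a plain tensor attribute on the module and avoids allocating a fresh tensor on a device derived from the inputs at every forward call:

import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden_size: int = 16, max_positions: int = 512):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        # Registered once at init: saved in state_dict, moved by .to(...), visible to TorchScript.
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        seq_length = input_ids.size(1)
        # Slice the pre-built ids instead of building a new tensor from input_ids.device.
        position_ids = self.position_ids[:, :seq_length]
        return self.word_embeddings(input_ids) + self.position_embeddings(position_ids)


scripted = torch.jit.script(ToyEmbeddings())   # scripting succeeds with the buffer attribute
tokens = torch.randint(0, 100, (2, 10))
print(scripted(tokens).shape)                  # torch.Size([2, 10, 16])

The same slice-the-buffer idiom is what each hunk below switches to.
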
@@ -180,6 +180,9 @@ class BertEmbeddings(nn.Module):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -187,12 +190,12 @@ class BertEmbeddings(nn.Module):
             input_shape = inputs_embeds.size()[:-1]

         seq_length = input_shape[1]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(input_shape)
+            position_ids = self.position_ids[:, :seq_length]
+
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

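A side effect the Bert and MobileBert hunks rely on: a tensor registered with register_buffer travels with the module through .to()/.cuda() and appears in state_dict(), which is why token_type_ids can now be created on self.position_ids.device instead of recomputing a device from whichever input happens to be present. A small standalone sketch (the WithPositions module name and size are illustrative only):

import torch
from torch import nn


class WithPositions(nn.Module):
    def __init__(self, max_positions: int = 512):
        super().__init__()
        # Same pattern as the diff above: build the ids once and register them as a buffer.
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))


m = WithPositions()
print(m.position_ids.device)              # cpu
print("position_ids" in m.state_dict())   # True: the buffer is exported when the model is serialized

if torch.cuda.is_available():
    m = m.cuda()
    print(m.position_ids.device)          # cuda:0, the buffer moved together with the module
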
@@ -179,18 +179,22 @@ class MobileBertEmbeddings(nn.Module):
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
             input_shape = inputs_embeds.size()[:-1]

         seq_length = input_shape[1]
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand(input_shape)
+            position_ids = self.position_ids[:, :seq_length]
+
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)

@@ -391,6 +391,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

+        self.register_buffer("position_ids", torch.arange(config.n_positions))
         self.init_weights()

     def get_input_embeddings(self):
@@ -443,9 +444,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         if position_ids is None:
             # Code is different from when we had a single embedding matrice from position and token embeddings
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = self.position_ids[None, : input_shape[-1]]

         # Attention mask.
         if attention_mask is not None:

@@ -442,6 +442,7 @@ class XLMModel(XLMPreTrainedModel):
                     self.prune_heads({int(layer): list(map(int, heads))})

         self.init_weights()
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

     def get_input_embeddings(self):
         return self.embeddings
@@ -511,12 +512,9 @@ class XLMModel(XLMPreTrainedModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-
         # position_ids
         if position_ids is None:
-            position_ids = torch.arange(slen, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
+            position_ids = self.position_ids[:, :slen]
         else:
             assert position_ids.size() == (bs, slen)  # (slen, bs)
             # position_ids = position_ids.transpose(0, 1)

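Across all four models the change is the same: the position ids that used to be built with torch.arange on every forward call are now a slice of the buffer registered at init. A quick, illustrative sanity check (sizes chosen arbitrarily) that the sliced buffer carries the same indices, with the (1, seq_len) slice broadcasting where the old (batch, seq_len) tensor was used:

import torch

max_position_embeddings, batch_size, seq_length = 512, 4, 10

# old pattern: allocated at forward time and expanded to the batch shape
old_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand((batch_size, seq_length))

# new pattern: sliced from the buffer that was registered once at init
position_ids_buffer = torch.arange(max_position_embeddings).expand((1, -1))
new_ids = position_ids_buffer[:, :seq_length]

assert torch.equal(new_ids.expand((batch_size, seq_length)), old_ids)
print(new_ids.shape)  # torch.Size([1, 10]), broadcast against (batch, seq_len, hidden) embeddings downstream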