Unverified Commit 53ee995a authored by Hamid Shojanazeri, committed by GitHub

Fix device id being hardcoded for token_type_ids when tracing ConvBert (#12287)

* Added a token_type_ids buffer to fix issue #5664

* Handled the case where the position_ids buffer is not registered

* Modified to support device conversion when the model is traced
parent 5adf5cab
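
A minimal sketch of the scenario this change addresses, assuming a CUDA device is available; the checkpoint name and input text are only illustrative. The model is traced on CPU without passing token_type_ids and then moved to the GPU. Before this fix, the zero-filled token_type_ids tensor was created inside forward() and its device was baked into the trace; with the registered buffer, it follows the module's device.

import torch
from transformers import ConvBertModel, ConvBertTokenizer

tokenizer = ConvBertTokenizer.from_pretrained("YituTech/conv-bert-base")
model = ConvBertModel.from_pretrained("YituTech/conv-bert-base", torchscript=True)
model.eval()

inputs = tokenizer("Tracing should not hardcode a device.", return_tensors="pt")

# Trace on CPU without passing token_type_ids; the registered buffer supplies the zeros.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

# Move the traced module to another device; the token_type_ids buffer moves with it
# instead of staying pinned to the device used during tracing.
traced = traced.to("cuda")
output = traced(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))
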
@@ -21,6 +21,7 @@ from operator import attrgetter
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -196,9 +197,14 @@ class ConvBertEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
@@ -211,8 +217,16 @@ class ConvBertEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, :seq_length]

+        # Set token_type_ids to the registered buffer from the constructor (all zeros), which usually happens
+        # when it is auto-generated. The registered buffer lets users trace the model without passing
+        # token_type_ids and solves issue #5664.
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -800,6 +814,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -810,7 +825,12 @@ class ConvBertModel(ConvBertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
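
For context, here is a standalone sketch of the buffer pattern the diff relies on; the ToyEmbeddings class and sizes are made up for illustration. A buffer registered with persistent=False follows .to(device) and is visible to tracing, but is not written to the state dict, so existing checkpoints still load. The persistent keyword is not available in older PyTorch releases, which is why the diff gates the registration on the torch version.

import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, max_positions=512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))
        # Non-persistent: moved by .to() and captured by tracing, but kept out of the state dict.
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        # Slicing the buffer keeps the device tied to the module instead of hardcoding it.
        return self.token_type_ids[:, :seq_length].expand(input_ids.size(0), seq_length)


module = ToyEmbeddings()
print("token_type_ids" in module.state_dict())  # False: the buffer is not saved or loaded
if torch.cuda.is_available():
    module = module.to("cuda")
print(module.token_type_ids.device)  # follows the module's device
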