Unverified Commit 53ee995a authored by Hamid Shojanazeri, committed by GitHub

Fix for the issue of device-id getting hardcoded for token_type_ids during Tracing for ConvBert (#12287)

* Added token_type_ids buffer to fix issue #5664

* Handled the case where the position_ids buffer is not registered

* Added token_type_ids buffer to fix issue #5664

* Modified to support device conversion when the model is traced
parent 5adf5cab
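Why the buffer fixes tracing: `torch.jit.trace` evaluates Python expressions such as `input.device` eagerly, so a `torch.zeros(..., device=...)` call inside `forward` freezes the trace-time device into the graph as a constant, and the traced model keeps allocating `token_type_ids` on that original device even after `.to()` is called. A registered buffer, by contrast, is a module attribute that `nn.Module.to()` moves together with the parameters. The patch registers it with `persistent=False` so it stays out of the state dict (older checkpoints still load), behind a version gate because the `persistent` keyword is only available on recent torch versions. A minimal sketch of the two patterns, using a hypothetical toy module rather than the ConvBert code:

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Hypothetical toy module illustrating the trace-time device problem."""

    def __init__(self, max_len=8):
        super().__init__()
        # Non-persistent buffer: excluded from the state dict, but moved by
        # nn.Module.to() together with the parameters, so a traced graph
        # reads it from whatever device the module currently lives on.
        self.register_buffer("token_type_ids", torch.zeros(1, max_len, dtype=torch.long), persistent=False)

    def forward(self, input_ids):
        # Problematic pattern under torch.jit.trace: input_ids.device is
        # evaluated eagerly in Python, so the trace-time device is frozen
        # into the graph as a constant.
        hardcoded = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
        # Patched pattern: slice the registered buffer instead.
        buffered = self.token_type_ids[:, : input_ids.size(1)].expand(input_ids.size(0), -1)
        return hardcoded, buffered


traced = torch.jit.trace(ToyEmbeddings(), torch.ones(2, 8, dtype=torch.long))
if torch.cuda.is_available():
    traced.to("cuda")
    hardcoded, buffered = traced(torch.ones(2, 8, dtype=torch.long, device="cuda"))
    # The first tensor may stay on the trace-time (CPU) device;
    # the second follows the module wherever it is moved.
    print(hardcoded.device, buffered.device)
```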
@@ -21,6 +21,7 @@ from operator import attrgetter
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -196,9 +197,14 @@ class ConvBertEmbeddings(nn.Module):
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
persistent=False,
)

def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
@@ -211,7 +217,15 @@ class ConvBertEmbeddings(nn.Module):
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]

# Setting token_type_ids to the registered buffer from the constructor, where it is all zeros. This usually
# happens when token_type_ids is auto-generated; the registered buffer lets users trace the model without
# passing token_type_ids, and solves issue #5664.
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

if inputs_embeds is None:
@@ -800,6 +814,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
@@ -810,6 +825,11 @@ class ConvBertModel(ConvBertPreTrainedModel):
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
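With this change, a ConvBert model traced without explicit token_type_ids should move cleanly between devices. A usage sketch, assuming the public YituTech/conv-bert-base checkpoint and access to a GPU:

```python
import torch
from transformers import ConvBertModel, ConvBertTokenizer

# Checkpoint name and GPU availability are assumptions for this sketch.
tokenizer = ConvBertTokenizer.from_pretrained("YituTech/conv-bert-base")
model = ConvBertModel.from_pretrained("YituTech/conv-bert-base", torchscript=True).eval()

inputs = tokenizer("ConvBERT can now be traced and moved between devices.", return_tensors="pt")
# token_type_ids is deliberately omitted so the model falls back to the registered buffer.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

if torch.cuda.is_available():
    # Before this fix, the auto-generated token_type_ids stayed on the
    # trace-time (CPU) device and this call produced a device mismatch.
    traced.to("cuda")
    outputs = traced(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))
```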