"configs/models/vscode:/vscode.git/clone" did not exist on "2b3d4150f3b036dfb993755847a16e5409dd1a65"
Unverified Commit af6e01c5 authored by Hamid Shojanazeri, committed by GitHub

Fix the device id being hardcoded for token_type_ids during tracing [WIP] (#11252)



* registering a buffer for token_type_ids to get past the error of the device id getting hardcoded when tracing

* style format

* adding the persistent flag to the registered buffers, which keeps them out of the state_dict and addresses the backward-compatibility issue

* adding a try/except to the fix, as the persistent flag is only available from PyTorch > 1.6

* adding a version check

* added the condition to only use the token_type_ids buffer when it is auto-generated, not passed by the user

* adding comments and making the condition where token_type_ids is None use the registered buffer

* taking out position-embedding from the if block

* adding comments

* handling the case where the buffer for position_ids was not registered

* reverted the changes on position_ids, fixed the issue with the size of the token_type_ids buffer, and moved the modification for generated token_type_ids to BertModel instead of Embeddings

* reverting the token_type_ids handling in the None case to the previous version

* reverting changes on position_ids, adding back the if block

* changes added by running make fix-copies

* changes added by running make fix-copies, and added the version import as it was being used

* changes added by running make fix-copies

* changes added by running make fix-copies

* fixing the import format

* fixing the import format

* modified to use a temp tensor for the trimmed and expanded token_type_ids buffer

* changes made by fix-copies after temp tensor modifications

* changes made by fix-copies after temp tensor modifications

* changes made by fix-copies after temp tensor modifications

* clean up

* clean up

* clean up

* clean up

* Nit

* Nit

* Nit

* modified to support device conversion on traced models

* modified to support device conversion on traced models

* modified to support device conversion on traced models

* modified to support device conversion on traced models

* changes based on latest in master

* Adapt templates

* Add version import
Co-authored-by: Ubuntu <ubuntu@ip-172-31-32-81.us-west-2.compute.internal>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 0d97ba8a
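For context, a minimal usage sketch of what this commit enables (the checkpoint name, the example sentence, and the CUDA availability check are illustrative assumptions, not taken from this PR): tracing a model without passing token_type_ids previously baked the trace-time device into the auto-generated zeros tensor, so the traced module could not be moved to another device; with token_type_ids held in a registered (non-persistent) buffer, it follows the module's device instead.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True).eval()

inputs = tokenizer("Tracing without token_type_ids", return_tensors="pt")
# token_type_ids is deliberately not passed, so the embeddings fall back to the registered buffer.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

if torch.cuda.is_available():
    # Buffers move with the module, so the traced graph no longer pins the
    # auto-generated token_type_ids to the device that was used at trace time.
    traced_cuda = traced.to("cuda")
    outputs = traced_cuda(inputs["input_ids"].to("cuda"), inputs["attention_mask"].to("cuda"))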
@@ -20,6 +20,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -216,6 +217,12 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
@@ -231,7 +238,15 @@ class AlbertEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -687,6 +702,7 @@ class AlbertModel(AlbertPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -697,6 +713,11 @@ class AlbertModel(AlbertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
...
@@ -24,6 +24,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -176,10 +177,15 @@ class BertEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
     def forward(
         self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
@@ -194,7 +200,15 @@ class BertEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -936,7 +950,13 @@ class BertModel(BertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
...
@@ -23,6 +23,7 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -254,10 +255,15 @@ class BigBirdEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
         # End copy
         self.rescale_embeddings = config.rescale_embeddings
@@ -276,7 +282,15 @@ class BigBirdEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -2025,6 +2039,11 @@ class BigBirdModel(BigBirdPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         # in order to use block_sparse attention, sequence_length has to be at least
...
@@ -21,6 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -169,6 +170,12 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
@@ -184,7 +191,15 @@ class ElectraEmbeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -839,6 +854,7 @@ class ElectraModel(ElectraPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -849,6 +865,11 @@ class ElectraModel(ElectraPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
...
@@ -19,6 +19,7 @@ import math
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -82,10 +83,15 @@ class RobertaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
         # End copy
         self.padding_idx = config.pad_token_id
@@ -99,9 +105,7 @@ class RobertaEmbeddings(nn.Module):
         if position_ids is None:
             if input_ids is not None:
                 # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = create_position_ids_from_input_ids(
-                    input_ids, self.padding_idx, past_key_values_length
-                ).to(input_ids.device)
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
             else:
                 position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
@@ -110,7 +114,17 @@ class RobertaEmbeddings(nn.Module):
         else:
             input_shape = inputs_embeds.size()[:-1]
+        seq_length = input_shape[1]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -780,7 +794,13 @@ class RobertaModel(RobertaPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
...
@@ -22,6 +22,7 @@ import os
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
@@ -156,6 +157,12 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "token_type_ids",
+                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+                persistent=False,
+            )
     def forward(
         self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
@@ -170,7 +177,15 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
         if inputs_embeds is None:
@@ -846,7 +861,13 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
...
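The buffer-registration pattern is identical in every file above. As a standalone sketch of just that pattern (the Embeddings class and its max_position_embeddings default are illustrative, not code from this commit), the version guard and persistent=False matter because the flag only exists in PyTorch >= 1.6 and because a non-persistent buffer stays out of the state_dict, keeping existing checkpoints loadable:

import torch
from packaging import version
from torch import nn

class Embeddings(nn.Module):
    """Illustrative module mirroring the buffer pattern added in this commit."""

    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # persistent=False keeps the buffer out of state_dict (backward compatible with
        # old checkpoints); the argument only exists in PyTorch >= 1.6, hence the guard.
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long),
                persistent=False,
            )

module = Embeddings()
print("position_ids" in module.state_dict())    # True: persistent buffers are serialized
print("token_type_ids" in module.state_dict())  # False: non-persistent buffers are not
# Both buffers still follow .to() device moves, which is what keeps traced models portable.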