Unverified Commit 5adf5cab authored by Hamid Shojanazeri's avatar Hamid Shojanazeri Committed by GitHub
Browse files

Fix for the issue of device-id getting hardcoded for position-ids during...

Fix for the issue of device-id getting hardcoded for position-ids during Tracing for Distillbert (#12290)

* registered buffer for position-ids to address issues similar to issue#5664

* added comment

* added the flag to prevent from adding the buffer into the state_dict
parent 5d1a3d13
...@@ -22,6 +22,7 @@ import math ...@@ -22,6 +22,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...@@ -101,6 +102,10 @@ class Embeddings(nn.Module): ...@@ -101,6 +102,10 @@ class Embeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids): def forward(self, input_ids):
""" """
...@@ -111,6 +116,13 @@ class Embeddings(nn.Module): ...@@ -111,6 +116,13 @@ class Embeddings(nn.Module):
embeddings) embeddings)
""" """
seq_length = input_ids.size(1) seq_length = input_ids.size(1)
# Setting the position-ids to the registered buffer in constructor, it helps
# when tracing the model without passing position-ids, solves
# isues similar to issue #5664
if hasattr(self, "position_ids"):
position_ids = self.position_ids[:, :seq_length]
else:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment