"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1025a9b74291b8256357476e554e252658217e11"
Unverified commit 5adf5cab authored by Hamid Shojanazeri, committed by GitHub


Fix for the issue of the device id getting hardcoded for position ids during tracing for DistilBERT (#12290)

* registered a buffer for position_ids to address issues similar to issue #5664

* added comment

* added the persistent=False flag to prevent the buffer from being added to the state_dict
parent 5d1a3d13
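
Background for the fix: before this change, `Embeddings.forward` created the position-ids tensor with `device=input_ids.device` on every call, so `torch.jit.trace` captured the tracing device in the graph and the traced model could keep referencing that original device after being moved elsewhere (the symptom reported in issue #5664). Below is a rough sketch of the trace-then-move workflow this commit targets, using the standard `transformers` TorchScript pattern; the checkpoint name is only an example, not something specified by this commit.

```python
import torch
from transformers import DistilBertModel, DistilBertTokenizer

# Sketch only: trace DistilBERT on CPU, then move the traced module to GPU.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torchscript=True).eval()

inputs = tokenizer("Hello, world!", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

# With position_ids registered as a buffer, .to("cuda") moves it along with the
# parameters, instead of a CPU device staying captured inside the traced graph.
if torch.cuda.is_available():
    traced = traced.to("cuda")
    cuda_inputs = {k: v.to("cuda") for k, v in inputs.items()}
    outputs = traced(cuda_inputs["input_ids"], cuda_inputs["attention_mask"])  # tuple; first item is the last hidden state
```
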
@@ -22,6 +22,7 @@ import math
 import numpy as np
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -101,6 +102,10 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+            )
     def forward(self, input_ids):
         """
@@ -111,8 +116,15 @@ class Embeddings(nn.Module):
         embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+        # Use the position_ids buffer registered in the constructor; this helps when
+        # tracing the model without passing position_ids and solves issues similar
+        # to issue #5664
+        if hasattr(self, "position_ids"):
+            position_ids = self.position_ids[:, :seq_length]
+        else:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
......
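
For readers unfamiliar with non-persistent buffers, here is a minimal, self-contained sketch of the `register_buffer(..., persistent=False)` pattern the diff introduces. It requires a torch version that supports the `persistent` keyword (the diff guards on torch > 1.6.0); `TinyEmbeddings` and the sizes are illustrative, not the actual DistilBERT `Embeddings` module.

```python
import torch
from torch import nn


class TinyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings=512, dim=16):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
        # Non-persistent buffer: moves with the module on .to(device) / .cuda(),
        # but is excluded from the state_dict, so existing checkpoints still load.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        # Slice the pre-built buffer instead of rebuilding position ids on
        # input_ids.device, so no device is captured at trace time.
        position_ids = self.position_ids[:, :seq_length]  # (1, seq_length)
        return self.position_embeddings(position_ids)


emb = TinyEmbeddings()
assert "position_ids" not in emb.state_dict()  # persistent=False keeps checkpoints unchanged
out = emb(torch.randint(0, 10, (2, 6)))        # (1, 6, 16); broadcasts against (bs, 6, 16) word embeddings
print(out.shape)
```

The `hasattr(self, "position_ids")` check in the diff simply falls back to the old per-call `torch.arange` path on older torch versions where the buffer was never registered.
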