Unverified Commit 5d1a3d13 authored by Hamid Shojanazeri, committed by GitHub

Fix for the issue of device-id getting hardcoded for position-ids during Tracing for Flaubert (#12292)

* adding position_ids buffer to fix the issue similar to #5664

* adding position-id buffer to address issues similar to #5664
parent 58e999b7
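
For context, a minimal sketch of the failure mode this commit addresses (the `Hardcoded`/`Buffered` module names are illustrative, not from the repo): `torch.jit.trace` records tensor factories such as `torch.arange(..., device=x.device)` with the device frozen into the graph as a constant, so a model traced on CPU keeps allocating its position ids on CPU even after `.to("cuda")`. Pre-computing the ids in `__init__` and registering them as a non-persistent buffer avoids this, since buffers move with the module and the trace only records a slice of an existing tensor.

import torch
from torch import nn

class Hardcoded(nn.Module):
    # Illustrative anti-pattern: ids are created inside forward(), so tracing
    # captures `x.device` as a constant baked into the graph.
    def forward(self, x):
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        return position_ids.unsqueeze(0).expand(x.size(0), -1)

class Buffered(nn.Module):
    # The pattern this commit adopts: ids live in a non-persistent buffer,
    # so Module.to(device) carries them along and the traced graph merely
    # slices an existing tensor instead of allocating a new one.
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, x):
        return self.position_ids[:, : x.size(1)].expand(x.size(0), -1)
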
@@ -18,6 +18,7 @@
 import random
 import torch
+from packaging import version
 from torch import nn
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -140,6 +141,10 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+            )
     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -198,10 +203,16 @@ class FlaubertModel(XLMModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

-        # position_ids
+        # Setting position_ids to the buffer registered in the constructor helps
+        # when tracing the model without passing position_ids, and solves
+        # issues similar to issue #5664
         if position_ids is None:
-            position_ids = torch.arange(slen, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
+            if hasattr(self, "position_ids"):
+                position_ids = self.position_ids[:, :slen]
+                position_ids = position_ids.expand((bs, slen))
+            else:
+                position_ids = torch.arange(slen, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).expand((bs, slen))
         else:
             assert position_ids.size() == (bs, slen)  # (slen, bs)
             # position_ids = position_ids.transpose(0, 1)
...
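
A hedged usage sketch of the workflow this fix enables (checkpoint name and input are illustrative; `torchscript=True` is the standard transformers flag for trace-friendly models): trace on CPU without passing position ids, then move the traced module to GPU, where the registered buffer follows the module instead of staying pinned to the tracing device.

import torch
from transformers import FlaubertModel, FlaubertTokenizer

tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")
model = FlaubertModel.from_pretrained("flaubert/flaubert_base_cased", torchscript=True)
model.eval()

inputs = tokenizer("Le chat dort.", return_tensors="pt")

# Trace without passing position_ids; the registered buffer supplies them.
traced = torch.jit.trace(model, (inputs["input_ids"],))

# Buffers (including position_ids) move with the module.
if torch.cuda.is_available():
    traced.to("cuda")
    output = traced(inputs["input_ids"].to("cuda"))
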