Unverified Commit eb1493b1 authored by Clémentine Fourrier, committed by GitHub

Fix #17893, removed dead code (#17917)

* Removed dead position_ids code, fixes #17893

* Removed unused var

* Now ignores the removed (dead) state dict key for backward compatibility
parent fbc7598b
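
The backward-compatibility note refers to checkpoint loading: `register_buffer` persists its tensor into the state dict by default, so checkpoints saved before this change still contain a `position_ids` entry that the new module no longer declares. A minimal sketch of that behavior, using hypothetical stand-in modules rather than the real transformers classes:

```python
import torch
import torch.nn as nn

class OldEmbeddings(nn.Module):
    """Stand-in for the module before this commit."""
    def __init__(self, max_position_embeddings=4098):
        super().__init__()
        # Buffers are persistent by default, so this tensor is written into
        # the state dict whenever the model is saved.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1))
        )

class NewEmbeddings(nn.Module):
    """Stand-in for the module after this commit: no buffer registered."""

old_state = OldEmbeddings().state_dict()
result = NewEmbeddings().load_state_dict(old_state, strict=False)
print(result.unexpected_keys)  # ['position_ids']
```

Because old checkpoints surface `position_ids` as an *unexpected* (rather than missing) key, the fix moves the pattern from `_keys_to_ignore_on_load_missing` to `_keys_to_ignore_on_load_unexpected`, as shown in the diff below.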
src/transformers/models/longformer/modeling_longformer.py

@@ -447,8 +447,6 @@ class LongformerEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.padding_idx = config.pad_token_id
@@ -469,13 +467,8 @@ class LongformerEmbeddings(nn.Module):
         else:
             input_shape = inputs_embeds.size()[:-1]

-        seq_length = input_shape[1]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=position_ids.device)
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
@@ -1392,7 +1385,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
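
Why the deleted lines in forward were dead: earlier in the same method, `position_ids` is always resolved before that point, either from `input_ids` via the module-level `create_position_ids_from_input_ids` helper or from `inputs_embeds` via `create_position_ids_from_inputs_embeds`, so the later `if position_ids is None` fallback onto the buffer could never trigger. A small sketch that mirrors the helper's logic (re-implemented here for illustration; the real function lives in modeling_longformer.py):

```python
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Non-padding tokens get consecutive positions starting at padding_idx + 1;
    # padding tokens keep padding_idx as their position id.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[0, 5, 6, 7, 1, 1]])  # assume 1 is the pad token id
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])
```

Since `position_ids` is guaranteed to be set by the time `token_type_ids` is built, the zeros tensor can be placed on `position_ids.device` directly, which is why the buffer (and its `self.position_ids.device` reference) could be dropped.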