Unverified Commit 894db670 authored by mingruimingrui's avatar mingruimingrui Committed by GitHub
Browse files

Bugfix: Removal of padding_idx in BartLearnedPositionalEmbedding (#10200)



* Assumption of padding_idx <2 might not stand

* Use offset instead of 2

* Fix with black

* Change behavior to warning instead for backward compatibility.

* Fix with black

* Remove warning

* Make padding_idx non-required

* padding_idx fix for blenderbot

* padding_idx fix for blenderbot_small

* padding_idx fix for led

* padding_idx fix for mbart

* Remove extra whitespaces

* padding_idx fix for template

* Fix padding_idx passed to nn.Embedding mistake

* Fixed padding_idx passed to positional embedding in template

* Remove padding_idx from pytorch learned positional embeddings

* Remove accidentally added quotes

* Remove padding_idx from tf learned positional embeddings

* Remove zeroing of weights in __init__
Co-authored-by: default avatarWang Ming Rui <mingrui.wang@C02CJTUYMD6M.local>
parent 55fe80d0
......@@ -108,12 +108,11 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
def __init__(self, num_embeddings: int, embedding_dim: int):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -673,7 +672,6 @@ class BartEncoder(BartPretrainedModel):
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
......@@ -836,7 +834,6 @@ class BartDecoder(BartPretrainedModel):
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......
......@@ -113,8 +113,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = 2
......@@ -632,7 +631,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
self.embed_positions = TFBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -793,7 +791,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
self.embed_positions = TFBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
......
......@@ -112,9 +112,8 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__(num_embeddings, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -635,7 +634,6 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
......@@ -800,7 +798,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
......
......@@ -118,8 +118,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
......@@ -629,7 +628,6 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -797,7 +795,6 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
......
......@@ -110,9 +110,8 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__(num_embeddings, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -636,7 +635,6 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([BlenderbotSmallEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
......@@ -800,7 +798,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......
......@@ -117,8 +117,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
......@@ -634,7 +633,6 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -802,7 +800,6 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
......
......@@ -112,9 +112,8 @@ class LEDLearnedPositionalEmbedding(nn.Embedding):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__(num_embeddings, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -1622,7 +1621,6 @@ class LEDEncoder(LEDPreTrainedModel):
self.embed_positions = LEDLearnedPositionalEmbedding(
self.max_source_positions,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
......@@ -1891,7 +1889,6 @@ class LEDDecoder(LEDPreTrainedModel):
self.embed_positions = LEDLearnedPositionalEmbedding(
self.max_target_positions,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......
......@@ -108,8 +108,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
......@@ -1612,7 +1611,6 @@ class TFLEDEncoder(tf.keras.layers.Layer):
self.embed_positions = TFLEDLearnedPositionalEmbedding(
config.max_encoder_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -1865,7 +1863,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
self.embed_positions = TFLEDLearnedPositionalEmbedding(
config.max_decoder_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
......
......@@ -114,12 +114,11 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
def __init__(self, num_embeddings: int, embedding_dim: int):
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -678,7 +677,6 @@ class MBartEncoder(MBartPreTrainedModel):
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
......@@ -844,7 +842,6 @@ class MBartDecoder(MBartPreTrainedModel):
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......
......@@ -115,8 +115,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = 2
......@@ -636,7 +635,6 @@ class TFMBartEncoder(tf.keras.layers.Layer):
self.embed_positions = TFMBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TFMBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -806,7 +804,6 @@ class TFMBartDecoder(tf.keras.layers.Layer):
self.embed_positions = TFMBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
......
......@@ -1565,8 +1565,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
assert padding_idx is not None, "padding_idx cannot be None"
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
......@@ -2017,7 +2016,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
......@@ -2160,7 +2158,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
name="embed_positions",
)
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
......
......@@ -1616,9 +1616,8 @@ class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embeddin
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__(num_embeddings, embedding_dim)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
......@@ -2172,7 +2171,6 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
)
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
......@@ -2335,7 +2333,6 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
)
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment