"tests/vscode:/vscode.git/clone" did not exist on "d185b5ed5f23c5912918ee81881a3c03f9359523"
Unverified Commit 0eabab09, authored by Joao Gante, committed by GitHub

TF: final bias as a layer in seq2seq models (replicate TFMarian fix) (#18903)
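As background (not part of the diff below), here is a minimal, self-contained sketch of the pattern this commit applies; the `TinyLMHead` model and the weights file name are made up for illustration. Because `tf.keras.Model.save_weights` stores weights on a per-layer basis, the final logits bias is wrapped in a dedicated `BiasLayer` so the weight is registered under a layer and survives a save/load round trip, while `final_logits_bias` is kept as a plain attribute alias to preserve the PyTorch-style interface.

```python
import tensorflow as tf


class BiasLayer(tf.keras.layers.Layer):
    """Bias as a layer, so the weight is registered for `save_weights` serialization."""

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # As noted in the diff, this variable name is NOT scoped under the outer layers when serialized.
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias


class TinyLMHead(tf.keras.Model):
    """Toy stand-in for a TF*ForConditionalGeneration head (illustration only)."""

    def __init__(self, vocab_size=16):
        super().__init__()
        self.proj = tf.keras.layers.Dense(vocab_size, name="proj")
        # Not trainable, mirroring the PyTorch buffer.
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def call(self, hidden_states):
        return self.bias_layer(self.proj(hidden_states))


model = TinyLMHead()
_ = model(tf.zeros((1, 4, 8)))  # build all weights
model.save_weights("tiny_lm_head.h5")  # per-layer HDF5 storage; the bias is saved because it lives in a layer
model.load_weights("tiny_lm_head.h5")
print([w.name for w in model.weights])
```

The same `BiasLayer` definition is copied into each seq2seq TF model below, with `# Copied from` markers pointing back to the BART implementation.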

parent 2b9513fd
@@ -1251,6 +1251,23 @@ class TFBartModel(TFBartPretrainedModel):
)


class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING,
@@ -1268,9 +1285,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1357,7 +1375,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -1239,6 +1239,24 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The BLENDERBOT Model with a language modeling head. Can be used for summarization.",
BLENDERBOT_START_DOCSTRING,
@@ -1254,9 +1272,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
        self.model = TFBlenderbotMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1358,7 +1377,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -1226,6 +1226,24 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.",
BLENDERBOT_SMALL_START_DOCSTRING,
@@ -1241,9 +1259,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
        self.model = TFBlenderbotSmallMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1330,7 +1349,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -2316,6 +2316,24 @@ class TFLEDModel(TFLEDPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The LED Model with a language modeling head. Can be used for summarization.",
LED_START_DOCSTRING,
@@ -2331,9 +2349,10 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
        self.led = TFLEDMainLayer(config, name="led")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

        # TODO (Joao): investigate why LED has numerical issues in XLA generate
        self.supports_xla_generation = False
@@ -2423,7 +2442,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
            training=training,
        )
        lm_logits = self.led.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -1269,6 +1269,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
@@ -1266,6 +1266,24 @@ class TFMBartModel(TFMBartPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The MBART Model with a language modeling head. Can be used for summarization.",
MBART_START_DOCSTRING,
@@ -1281,9 +1299,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
        self.model = TFMBartMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1368,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -1278,6 +1278,24 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The PEGASUS Model with a language modeling head. Can be used for summarization.",
PEGASUS_START_DOCSTRING,
@@ -1293,9 +1311,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
        self.model = TFPegasusMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1382,7 +1401,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
@@ -2806,6 +2806,24 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
)


# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias
@add_start_docstrings(
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization.",
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
@@ -2822,9 +2840,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
        self.model._set_save_spec(inputs=self.serving.input_signature)
        self.use_cache = config.use_cache
        # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-       self.final_logits_bias = self.add_weight(
+       self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
+       self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -2911,7 +2930,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
            training=training
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
-       lm_logits = lm_logits + self.final_logits_bias
+       lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
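As a hedged usage sketch (assuming the `transformers` TensorFlow extras are installed and the `facebook/bart-base` checkpoint is reachable; any of the seq2seq models touched above would work the same way), the public interface is unchanged after this commit: the logits still include the final bias, and `final_logits_bias` remains available as an alias of the new layer's weight.

```python
from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# The alias added in this commit keeps attribute parity with the PyTorch model.
assert model.final_logits_bias is model.bias_layer.bias

inputs = tokenizer("Hello world", return_tensors="tf")
outputs = model(inputs["input_ids"], decoder_input_ids=inputs["input_ids"])
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size); the bias is already applied
```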