Unverified Commit 9ed80b00 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: TFBart embedding initialization (#19460)

* correct embedding init
parent b651efe5
...@@ -2059,12 +2059,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2059,12 +2059,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Return: Return:
`tf.keras.layers.Embedding`: Resized Embedding layer. `tf.keras.layers.Embedding`: Resized Embedding layer.
""" """
# Get the initialization range for the embeddings
init_range = 0.02 # default value
potential_initialization_variable_names = [
"initializer_range", # most common
"initializer_factor", # e.g. T5
"init_std", # e.g BART
]
for var_name in potential_initialization_variable_names:
if hasattr(self.config, var_name):
init_range = getattr(self.config, var_name)
# Get a new (initialized) embeddings layer # Get a new (initialized) embeddings layer
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = tf.keras.layers.Embedding( new_embeddings = tf.keras.layers.Embedding(
input_dim=new_num_tokens, input_dim=new_num_tokens,
output_dim=old_embeddings.output_dim, output_dim=old_embeddings.output_dim,
embeddings_initializer=get_initializer(init_range), embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_range),
name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0" name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
) )
new_embeddings(tf.constant([[0]])) new_embeddings(tf.constant([[0]]))
......
...@@ -1053,7 +1053,12 @@ class TFBartMainLayer(tf.keras.layers.Layer): ...@@ -1053,7 +1053,12 @@ class TFBartMainLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs): def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config self.config = config
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared") self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared",
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights) # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment