Unverified Commit 8a6928e2 authored by Joao Gante, committed by GitHub

TF: correct TFBart embeddings weights name when load_weight_prefix is passed (#18993)

parent c126a239
@@ -16,6 +16,7 @@
 import random
+from contextlib import nullcontext
 from typing import Optional, Tuple, Union

 import numpy as np
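The only change in this hunk is the `nullcontext` import, which the encoder/decoder changes below rely on as a no-op stand-in when no prefix is set. A minimal sketch of that pattern, assuming TensorFlow is available (the helper name is hypothetical, not part of the patch):

    from contextlib import nullcontext

    import tensorflow as tf

    def call_in_optional_scope(layer, inputs):
        # Hypothetical helper mirroring the patch: enter the layer's
        # `load_weight_prefix` name scope when the attribute exists, and fall
        # back to nullcontext(), which does nothing, when it does not.
        if hasattr(layer, "load_weight_prefix"):
            context_manager = tf.name_scope(layer.load_weight_prefix + "/")
        else:
            context_manager = nullcontext()
        with context_manager:
            return layer(inputs)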
@@ -748,7 +749,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
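The trailing `/` in `tf.name_scope(... + "/")` is the load-bearing detail: a scope name ending in `/` is treated as fully specified and replaces the active scope instead of nesting under it. A standalone check of that behaviour, run in graph mode so the op names are visible (the scope and op names here are illustrative, not taken from the patch):

    import tensorflow as tf

    with tf.Graph().as_default():
        with tf.name_scope("encoder"):
            nested = tf.constant(0.0, name="x")        # nests: "encoder/x:0"
            with tf.name_scope("model.shared/"):       # trailing "/" -> absolute scope
                replaced = tf.constant(0.0, name="x")  # "model.shared/x:0"

    assert nested.name == "encoder/x:0"
    assert replaced.name == "model.shared/x:0"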
@@ -936,7 +945,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
         positions = self.embed_positions(input_shape, position_ids=position_ids)

         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         hidden_states = inputs_embeds
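The decoder gets the identical treatment, and the parenthetical comment spells out the net effect: the shared weight should end up registered as `<load_weight_prefix>/<layer name>/embeddings:0`. A toy check of that naming under the TF/Keras versions this patch targets (the prefix value is made up):

    import tensorflow as tf

    embed = tf.keras.layers.Embedding(input_dim=10, output_dim=4, name="model.shared")
    embed.load_weight_prefix = "demo_prefix"  # illustrative value, not from the patch

    # Building the layer inside the prefix scope registers its weight under it.
    with tf.name_scope(embed.load_weight_prefix + "/"):
        _ = embed(tf.constant([[1, 2, 3]]))

    print(embed.weights[0].name)  # expected: "demo_prefix/model.shared/embeddings:0"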
@@ -1032,8 +1049,9 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
-        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
+        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

         self.encoder = TFBartEncoder(config, self.shared, name="encoder")
         self.decoder = TFBartDecoder(config, self.shared, name="decoder")
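At the call site nothing changes: the prefix still arrives through the constructor, but it now only tags the shared layer instead of renaming it, so the Keras layer name stays stable. A usage sketch, assuming a working transformers + TensorFlow install (the prefix string is made up):

    from transformers import BartConfig
    from transformers.models.bart.modeling_tf_bart import TFBartMainLayer

    config = BartConfig()

    # Composite models can pass their own prefix; after this patch only the
    # load/store scope changes, while the layer keeps its "model.shared" name.
    layer = TFBartMainLayer(config, load_weight_prefix="parent.model.shared", name="model")
    assert layer.shared.name == "model.shared"
    assert layer.shared.load_weight_prefix == "parent.model.shared"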