Unverified commit 6dc0a849, authored by Matt, committed by GitHub

Fix weight tying in TF-ESM (#22839)

Fix weight tying in ESM
parent 3b61d289
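
Note: previously, TFEsmLMHead always created its own bias-free Dense decoder, so the output projection was never actually shared with the input word embeddings, even when config.tie_word_embeddings was set. After this change, the LM head's decoder points directly at the embedding matrix when tying is enabled, and a separate (vocab_size, hidden_size) weight is allocated only when it is not. A minimal sketch of the tying pattern, with toy shapes rather than the real model:

    import tensorflow as tf

    vocab_size, hidden_size = 33, 16  # toy dimensions, not ESM's real ones

    # Input embeddings: a (vocab_size, hidden_size) matrix.
    embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size)
    embeddings.build((None, None))   # materialize the weight so there is something to tie
    decoder = embeddings.weights[0]  # tied: the same tf.Variable, not a copy

    # LM head projection: hidden states back to vocabulary logits via the
    # transposed embedding matrix, plus a separate bias.
    hidden_states = tf.random.normal((2, 7, hidden_size))  # (batch, seq, hidden)
    bias = tf.zeros((vocab_size,))
    logits = tf.matmul(hidden_states, decoder, transpose_b=True) + bias
    assert logits.shape == (2, 7, vocab_size)
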
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch ESM model."""
+import os
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]

     def get_output_embeddings(self):
         return self.lm_head.decoder
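
The explicit tf.name_scope wrapper matters here: a Keras layer owns no variables until build() runs, and a variable's name, which checkpoint loading matches against, is fixed by the name scope active when it is created. Building the embeddings early under the scope they would normally get keeps the weight name stable. An illustrative sketch (the scope string is made up to mirror the code above):

    import tensorflow as tf

    layer = tf.keras.layers.Embedding(10, 4)
    assert not layer.weights  # nothing to tie yet -- build() hasn't run

    # Build under an explicit scope so the variable is created with a
    # checkpoint-compatible prefix rather than whatever scope is active here.
    with tf.name_scope("esm/embeddings/word_embeddings"):
        layer.build((None, None))

    print(layer.weights[0].name)  # carries the "esm/embeddings/word_embeddings/..." prefix
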
@@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer):
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.vocab_size, self.config.hidden_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
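
The dotted weight name "decoder.weight" is deliberate: PT/TF weight cross-loading matches variables by name, and creating the weight inside build() rather than __init__ means it is created under the layer's name scope, lining it up with the PyTorch parameter lm_head.decoder.weight. A toy illustration of the same pattern (class and names are hypothetical):

    import tensorflow as tf

    class TinyHead(tf.keras.layers.Layer):
        """Hypothetical head mimicking the add_weight-in-build pattern."""

        def build(self, input_shape):
            super().build(input_shape)
            # Created at build time, under whatever name scope is active then.
            self.decoder = self.add_weight(
                "decoder.weight", shape=(12, 8), initializer="zeros", trainable=True
            )

    head = TinyHead(name="lm_head")
    with tf.name_scope(head.name):  # mirror how a parent model builds sublayers
        head.build(None)

    print(head.decoder.name)  # e.g. "lm_head/decoder.weight:0"
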
@@ -1234,8 +1244,7 @@
         x = self.layer_norm(x)

         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x
@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_save_load_after_resize_token_embeddings(self):
         pass

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            if model_class is TFEsmForMaskedLM:
+                # Output embedding test differs from the main test because they're a matrix, not a layer
+                name = model.get_bias()
+                assert isinstance(name, dict)
+                for k, v in name.items():
+                    assert isinstance(v, tf.Variable)
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None
+
 @require_tf
 class TFEsmModelIntegrationTest(unittest.TestCase):
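
Beyond the common-attributes test above, a direct way to confirm the tying, sketched here with a hypothetical toy config rather than a released checkpoint, is to check that the LM head's decoder is the very same variable as the input embedding matrix:

    from transformers import EsmConfig, TFEsmForMaskedLM

    config = EsmConfig(
        vocab_size=33, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
        intermediate_size=64, pad_token_id=0, tie_word_embeddings=True,
    )
    model = TFEsmForMaskedLM(config)

    # With tying enabled, both names refer to one tf.Variable -- updating the
    # embeddings during training updates the output projection too.
    assert model.get_output_embeddings() is model.esm.embeddings.word_embeddings.weights[0]
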