Unverified commit 5600e6f3 authored by Matt, committed by GitHub

Hardcode GELU as the intermediate activation for ESM (#22892)

* Hardcode GELU as the intermediate activation for ESM

* Sneak a quick fix to the weight tying in too

* Make the call to gelu explicit
parent 874c7caf
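The diff below makes two changes: the TF ESM intermediate layer now calls GELU directly instead of resolving config.hidden_act through get_tf_activation, and the LM head's untied decoder becomes a proper Dense layer. As a quick sanity check (not part of the commit), the snippet below compares tf.nn.gelu, whose default approximate=False gives the exact erf-based formulation, against that formula written out by hand; the assumption is that this exact GELU is the activation the ESM checkpoints expect.

import numpy as np
import tensorflow as tf

x = tf.constant(np.linspace(-5.0, 5.0, 101), dtype=tf.float32)

# Exact (erf-based) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))
reference = x * 0.5 * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))

# The activation the new code hardcodes; approximate=False is the default
hardcoded = tf.nn.gelu(x)

# Should print ~1e-7 or smaller (float32 rounding only)
print(float(tf.reduce_max(tf.abs(hardcoded - reference))))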
@@ -37,7 +37,6 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    get_tf_activation,
     shape_list,
     unpack_inputs,
 )
@@ -476,24 +475,19 @@ class TFEsmAttention(Layer):
         return outputs
-# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Esm
 class TFEsmIntermediate(tf.keras.layers.Layer):
     def __init__(self, config: EsmConfig, **kwargs):
         super().__init__(**kwargs)
         self.dense = tf.keras.layers.Dense(
-            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+            units=config.intermediate_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
-        if isinstance(config.hidden_act, str):
-            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
-        else:
-            self.intermediate_act_fn = config.hidden_act

     def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = tf.nn.gelu(hidden_states)
         return hidden_states
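For reference, here is the simplified intermediate block as a standalone sketch, condensed from the new lines above. SimpleNamespace and the TruncatedNormal initializer stand in for the real EsmConfig and the library's get_initializer helper; the sizes are arbitrary.

from types import SimpleNamespace

import tensorflow as tf

# Toy stand-in for EsmConfig
config = SimpleNamespace(intermediate_size=128, initializer_range=0.02)


class TFEsmIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=config.intermediate_size,
            # TruncatedNormal stands in for transformers' get_initializer helper
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=config.initializer_range),
            name="dense",
        )

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        # Activation is now fixed to exact GELU instead of being read from config.hidden_act
        hidden_states = tf.nn.gelu(hidden_states)
        return hidden_states


layer = TFEsmIntermediate(config)
print(layer(tf.random.normal((2, 5, 32))).shape)  # (2, 5, 128)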
@@ -1216,23 +1210,21 @@ class TFEsmLMHead(Layer):
         )
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.decoder = None
+        if config.tie_word_embeddings:
+            self.decoder = None
+        else:
+            self.decoder = Dense(
+                config.vocab_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="decoder",
+                use_bias=False,
+            )
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
-        if not self.config.tie_word_embeddings:
-            if self.decoder is not None:
-                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
-            self.decoder = self.add_weight(
-                "decoder.weight",
-                shape=(self.config.hidden_size, self.config.vocab_size),
-                initializer=get_initializer(self.config.initializer_range),
-                trainable=True,
-            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
@@ -1244,7 +1236,10 @@ class TFEsmLMHead(Layer):
         x = self.layer_norm(x)
         # project back to size of vocabulary with bias
-        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        if self.config.tie_word_embeddings:
+            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        else:
+            x = self.decoder(x) + self.bias
         return x
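A condensed sketch of the LM head after the fix, showing the two decoder paths: when embeddings are tied, the decoder is presumably populated with the shared embedding matrix elsewhere by the model's tying logic and the projection stays a transposed matmul; when they are not tied, the decoder is now a bias-free Dense layer created in __init__ rather than a raw weight added during build. The toy config and shapes below are illustrative only.

from types import SimpleNamespace

import tensorflow as tf

# Toy stand-in for EsmConfig; the untied path is exercised below
config = SimpleNamespace(hidden_size=32, vocab_size=100, tie_word_embeddings=False)


class ToyLMHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.tie_word_embeddings:
            # Presumably set to the shared embedding matrix by the model's tying logic (not shown)
            self.decoder = None
        else:
            # Untied: a real, bias-free Dense layer instead of a raw weight created in build()
            self.decoder = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="decoder")
        self.config = config

    def build(self, input_shape):
        super().build(input_shape)
        # Bias stays a separate weight, created in build() so its name lines up with the PT layout
        self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

    def call(self, x):
        if self.config.tie_word_embeddings:
            # Tied: decoder holds a (vocab_size, hidden_size) matrix, so project with a transposed matmul
            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
        else:
            x = self.decoder(x) + self.bias
        return x


head = ToyLMHead(config)
print(head(tf.random.normal((2, 4, config.hidden_size))).shape)  # (2, 4, 100)

Keeping the bias as a separate weight, still created in build, preserves the checkpoint layout that PyTorch-to-TF cross-loading relies on, per the comment retained in the diff.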