Unverified Commit 5600e6f3 authored by Matt, committed by GitHub

Hardcode GELU as the intermediate activation for ESM (#22892)

* Hardcode GELU as the intermediate activation for ESM

* Sneak a quick fix to the weight tying in too

* Make the call to gelu explicit
parent 874c7caf
src/transformers/models/esm/modeling_tf_esm.py

```diff
@@ -37,7 +37,6 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    get_tf_activation,
     shape_list,
     unpack_inputs,
 )
```
```diff
@@ -476,24 +475,19 @@ class TFEsmAttention(Layer):
         return outputs


-# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Esm
 class TFEsmIntermediate(tf.keras.layers.Layer):
     def __init__(self, config: EsmConfig, **kwargs):
         super().__init__(**kwargs)

         self.dense = tf.keras.layers.Dense(
-            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+            units=config.intermediate_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
         )
-        if isinstance(config.hidden_act, str):
-            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
-        else:
-            self.intermediate_act_fn = config.hidden_act

     def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = tf.nn.gelu(hidden_states)
         return hidden_states
```
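For context: before this change, `TFEsmIntermediate` resolved its activation from `config.hidden_act` via the `get_tf_activation` string lookup; afterwards it always applies GELU directly. A minimal sketch (not part of the commit) of why this is behavior-preserving for ESM checkpoints, assuming `config.hidden_act` is `"gelu"` and using `tf.keras.activations.gelu` as a stand-in for the lookup, since both default to the exact, erf-based GELU:

```python
import tensorflow as tf

x = tf.constant([[-1.0, 0.0, 1.0, 2.0]])

# What the layer effectively did before: look the activation up by name.
# tf.keras.activations.gelu defaults to the exact (erf-based) form.
y_lookup = tf.keras.activations.gelu(x)

# What the layer does now: call gelu explicitly.
# tf.nn.gelu also defaults to approximate=False, i.e. the exact form.
y_direct = tf.nn.gelu(x)

tf.debugging.assert_near(y_lookup, y_direct)  # identical for the "gelu" setting
```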
```diff
@@ -1216,23 +1210,21 @@ class TFEsmLMHead(Layer):
         )

         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.decoder = None
+        if config.tie_word_embeddings:
+            self.decoder = None
+        else:
+            self.decoder = Dense(
+                config.vocab_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="decoder",
+                use_bias=False,
+            )
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
-        if not self.config.tie_word_embeddings:
-            if self.decoder is not None:
-                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
-            self.decoder = self.add_weight(
-                "decoder.weight",
-                shape=(self.config.hidden_size, self.config.vocab_size),
-                initializer=get_initializer(self.config.initializer_range),
-                trainable=True,
-            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
```
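The weight-tying fix above moves the untied decoder out of `build()`, where it was a raw `add_weight` variable with a dot in its name, and into `__init__` as a proper `Dense` sublayer. A rough, hypothetical illustration (not from the PR) of how the two approaches name their variables, which matters for cross-loading against PyTorch checkpoint keys; the exact printed names may vary across TF/Keras versions:

```python
import tensorflow as tf

class OldStyleHead(tf.keras.layers.Layer):
    # Untied decoder created as a raw weight inside build()
    def build(self, input_shape):
        super().build(input_shape)
        self.decoder = self.add_weight(
            "decoder.weight", shape=(4, 10), initializer="zeros", trainable=True
        )

    def call(self, x):
        return tf.matmul(x, self.decoder)

class NewStyleHead(tf.keras.layers.Layer):
    # Untied decoder created as a proper sublayer in __init__()
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.decoder = tf.keras.layers.Dense(10, use_bias=False, name="decoder")

    def call(self, x):
        return self.decoder(x)

x = tf.zeros((1, 4))
old_head = OldStyleHead(name="lm_head")
new_head = NewStyleHead(name="lm_head")
old_head(x)
new_head(x)

print([w.name for w in old_head.weights])  # e.g. ['lm_head/decoder.weight:0']
print([w.name for w in new_head.weights])  # e.g. ['lm_head/decoder/kernel:0']
```

The practical upshot is that the untied decoder now lives under a `decoder` name scope like any other sublayer, instead of as a loose variable with a dotted name.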
```diff
@@ -1244,7 +1236,10 @@ class TFEsmLMHead(Layer):
         x = self.layer_norm(x)

         # project back to size of vocabulary with bias
-        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        if self.config.tie_word_embeddings:
+            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        else:
+            x = self.decoder(x) + self.bias
         return x
```
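In the updated `call()`, the tied branch projects with the transposed embedding matrix, while the untied branch runs the new `decoder` Dense layer. A small illustrative sketch (not from the commit) of why the two branches compute the same logits whenever the Dense kernel equals the transpose of the embedding matrix:

```python
import tensorflow as tf

hidden_size, vocab_size = 8, 16
x = tf.random.normal((2, hidden_size))

# Tied branch: project with the transposed (vocab_size, hidden_size) embedding matrix.
embeddings = tf.random.normal((vocab_size, hidden_size))
logits_tied = tf.matmul(x, embeddings, transpose_b=True)

# Untied branch: a bias-free Dense layer whose (hidden_size, vocab_size) kernel
# is set to the transpose of the embedding matrix.
decoder = tf.keras.layers.Dense(vocab_size, use_bias=False)
decoder.build((None, hidden_size))
decoder.kernel.assign(tf.transpose(embeddings))
logits_untied = decoder(x)

# Both branches produce the same logits (bias term omitted for brevity).
tf.debugging.assert_near(logits_tied, logits_untied, atol=1e-5)
```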