Unverified Commit 161a6461 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix TF template (#9234)

parent 5a8a4eb1
...@@ -310,18 +310,22 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer): ...@@ -310,18 +310,22 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense( self.dense = tf.keras.layers.experimental.EinsumDense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="dense",
) )
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act) self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(inputs=hidden_states)
return hidden_states return hidden_states
...@@ -331,16 +335,20 @@ class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer): ...@@ -331,16 +335,20 @@ class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense( self.dense = tf.keras.layers.experimental.EinsumDense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False): def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.