Commit a0cd17e0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 457383160
parent a5bbb547
@@ -19,6 +19,7 @@ import gin
 import tensorflow as tf
 from official.modeling import tf_utils
+from official.nlp.modeling.layers import util


 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -57,9 +58,9 @@ class GatedFeedforward(tf.keras.layers.Layer):
   """

   def __init__(self,
-               intermediate_size,
-               intermediate_activation,
-               dropout,
+               inner_dim=768,
+               inner_activation=tf_utils.get_activation("gelu"),
+               dropout=0.0,
                use_gate=True,
                apply_output_layer_norm=True,
                num_blocks=1,
@@ -72,9 +73,12 @@ class GatedFeedforward(tf.keras.layers.Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    super(GatedFeedforward, self).__init__(**kwargs)
-    self._intermediate_size = intermediate_size
-    self._intermediate_activation = intermediate_activation
+    inner_dim = kwargs.pop("intermediate_size", inner_dim)
+    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
+    util.filter_kwargs(kwargs)
+    super().__init__(**kwargs)
+    self._inner_dim = inner_dim
+    self._inner_activation = inner_activation
     self._dropout = dropout
     self._use_gate = use_gate
     self._num_blocks = num_blocks
@@ -103,7 +107,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
     self._intermediate_dense = []
-    self._intermediate_activation_layers = []
+    self._inner_activation_layers = []
     self._gate_dense = []
     self._output_dense = []
     self._output_dropout = []
@@ -118,7 +122,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
       self._intermediate_dense.append(
           tf.keras.layers.EinsumDense(
               "abc,cd->abd",
-              output_shape=(None, self._intermediate_size),
+              output_shape=(None, self._inner_dim),
               bias_axes="d",
               name="intermediate_%d" % i,
               kernel_initializer=tf_utils.clone_initializer(
@@ -126,14 +130,14 @@ class GatedFeedforward(tf.keras.layers.Layer):
               bias_initializer=tf_utils.clone_initializer(
                   self._bias_initializer),
               **common_kwargs))
-      self._intermediate_activation_layers.append(
+      self._inner_activation_layers.append(
           tf.keras.layers.Activation(
-              self._intermediate_activation, dtype=activation_policy))
+              self._inner_activation, dtype=activation_policy))
       if self._use_gate:
         self._gate_dense.append(
             tf.keras.layers.EinsumDense(
                 "abc,cd->abd",
-                output_shape=(None, self._intermediate_size),
+                output_shape=(None, self._inner_dim),
                 bias_axes="d",
                 name="gate_%d" % i,
                 kernel_initializer=tf_utils.clone_initializer(
@@ -164,10 +168,10 @@ class GatedFeedforward(tf.keras.layers.Layer):

   def get_config(self):
     config = {
-        "intermediate_size":
-            self._intermediate_size,
-        "intermediate_activation":
-            self._intermediate_activation,
+        "inner_dim":
+            self._inner_dim,
+        "inner_activation":
+            self._inner_activation,
         "dropout":
             self._dropout,
         "use_gate":
@@ -191,7 +195,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
         "bias_constraint":
             tf.keras.constraints.serialize(self._bias_constraint)
     }
-    base_config = super(GatedFeedforward, self).get_config()
+    base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))

   def call(self, inputs):
@@ -199,7 +203,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
     for i in range(self._num_blocks):
       layer_input = layer_output
       intermediate_output = self._intermediate_dense[i](layer_input)
-      intermediate_output = self._intermediate_activation_layers[i](
+      intermediate_output = self._inner_activation_layers[i](
           intermediate_output)
       if self._use_gate:
         gated_linear = self._gate_dense[i](layer_input)
...
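Note on backward compatibility: the constructor now takes `inner_dim` / `inner_activation` (with defaults), but the old `intermediate_size` / `intermediate_activation` keywords are still accepted — they are popped out of `**kwargs` and mapped onto the new names before `util.filter_kwargs` drops any remaining legacy keys. A minimal sketch of both call styles, assuming `GatedFeedforward` is exported from `official.nlp.modeling.layers` as in the Model Garden package:

import tensorflow as tf
from official.nlp.modeling import layers

# New-style arguments introduced by this change.
ffn_new = layers.GatedFeedforward(
    inner_dim=128, inner_activation="relu", dropout=0.1)

# Legacy arguments still work: __init__ pops "intermediate_size" and
# "intermediate_activation" out of **kwargs and maps them onto the new
# parameter names before calling super().__init__().
ffn_legacy = layers.GatedFeedforward(
    intermediate_size=128, intermediate_activation="relu", dropout=0.1)

# The final dense in each block projects back to the input width, so the
# output shape matches the input shape.
x = tf.ones((2, 4, 16))
assert ffn_new(x).shape == ffn_legacy(x).shape == (2, 4, 16)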
@@ -44,8 +44,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
   def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
     tf.keras.mixed_precision.set_global_policy(dtype)
     kwargs = dict(
-        intermediate_size=128,
-        intermediate_activation="relu",
+        inner_dim=128,
+        inner_activation="relu",
         dropout=0.1,
         use_gate=use_gate,
         num_blocks=num_blocks,
@@ -76,8 +76,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
                             dtype):
     tf.keras.mixed_precision.set_global_policy(dtype)
     kwargs = dict(
-        intermediate_size=16,
-        intermediate_activation="relu",
+        inner_dim=16,
+        inner_activation="relu",
         dropout=0.1,
         use_gate=use_gate,
         num_blocks=num_blocks,
@@ -104,8 +104,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):

   def test_serialize_deserialize(self):
     kwargs = dict(
-        intermediate_size=16,
-        intermediate_activation="relu",
+        inner_dim=16,
+        inner_activation="relu",
         dropout=0.1,
         use_gate=False,
         num_blocks=4,
...
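The renamed `get_config` keys are exactly what `test_serialize_deserialize` round-trips. The same check outside the test harness, under the same import assumption as above:

from official.nlp.modeling import layers

ffn = layers.GatedFeedforward(
    inner_dim=16, inner_activation="relu", dropout=0.1,
    use_gate=False, num_blocks=4)
config = ffn.get_config()
# get_config now reports the renamed keys.
assert config["inner_dim"] == 16
assert config["inner_activation"] == "relu"
# A layer rebuilt from its config should serialize back to the same config.
clone = layers.GatedFeedforward.from_config(config)
assert clone.get_config() == config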
@@ -76,7 +76,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
                                         attention_dropout_rate)
     dropout_rate = kwargs.pop("output_dropout", dropout_rate)
     inner_dim = kwargs.pop("intermediate_size", inner_dim)
-    inner_activation = kwargs.pop("inner_activation", inner_activation)
+    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
     util.filter_kwargs(kwargs)
     super().__init__(**kwargs)
...
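The one-line `ReZeroTransformer` change is a bug fix rather than a rename: the old code popped the new name `"inner_activation"` from `**kwargs`, which is a no-op (that key arrives as a named parameter, never through `**kwargs`), so a caller's legacy `intermediate_activation` value was silently discarded. A sketch of the pattern on a hypothetical layer (`util.filter_kwargs`, which strips remaining known-legacy keys, is omitted here):

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
  """Hypothetical layer showing the legacy-kwarg mapping used in this commit."""

  def __init__(self, inner_dim=768, inner_activation="gelu", **kwargs):
    # Pop the *old* keyword name, falling back to the new parameter.
    # Popping the new name instead (the pre-fix ReZeroTransformer code)
    # is a no-op, so a legacy value would never reach the layer.
    inner_dim = kwargs.pop("intermediate_size", inner_dim)
    inner_activation = kwargs.pop("intermediate_activation", inner_activation)
    super().__init__(**kwargs)
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation

# Both spellings configure the layer identically.
assert MyLayer(inner_dim=32)._inner_dim == 32
assert MyLayer(intermediate_size=32)._inner_dim == 32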