Commit a0cd17e0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 457383160
parent a5bbb547
......@@ -19,6 +19,7 @@ import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -57,9 +58,9 @@ class GatedFeedforward(tf.keras.layers.Layer):
"""
def __init__(self,
intermediate_size,
intermediate_activation,
dropout,
inner_dim=768,
inner_activation=tf_utils.get_activation("gelu"),
dropout=0.0,
use_gate=True,
apply_output_layer_norm=True,
num_blocks=1,
......@@ -72,9 +73,12 @@ class GatedFeedforward(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(GatedFeedforward, self).__init__(**kwargs)
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
self._inner_dim = inner_dim
self._inner_activation = inner_activation
self._dropout = dropout
self._use_gate = use_gate
self._num_blocks = num_blocks
......@@ -103,7 +107,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._intermediate_dense = []
self._intermediate_activation_layers = []
self._inner_activation_layers = []
self._gate_dense = []
self._output_dense = []
self._output_dropout = []
......@@ -118,7 +122,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
self._intermediate_dense.append(
tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
output_shape=(None, self._inner_dim),
bias_axes="d",
name="intermediate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
......@@ -126,14 +130,14 @@ class GatedFeedforward(tf.keras.layers.Layer):
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._intermediate_activation_layers.append(
self._inner_activation_layers.append(
tf.keras.layers.Activation(
self._intermediate_activation, dtype=activation_policy))
self._inner_activation, dtype=activation_policy))
if self._use_gate:
self._gate_dense.append(
tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
output_shape=(None, self._inner_dim),
bias_axes="d",
name="gate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
......@@ -164,10 +168,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
def get_config(self):
config = {
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"dropout":
self._dropout,
"use_gate":
......@@ -191,7 +195,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(GatedFeedforward, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
......@@ -199,7 +203,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
for i in range(self._num_blocks):
layer_input = layer_output
intermediate_output = self._intermediate_dense[i](layer_input)
intermediate_output = self._intermediate_activation_layers[i](
intermediate_output = self._inner_activation_layers[i](
intermediate_output)
if self._use_gate:
gated_linear = self._gate_dense[i](layer_input)
......
......@@ -44,8 +44,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=128,
intermediate_activation="relu",
inner_dim=128,
inner_activation="relu",
dropout=0.1,
use_gate=use_gate,
num_blocks=num_blocks,
......@@ -76,8 +76,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
inner_dim=16,
inner_activation="relu",
dropout=0.1,
use_gate=use_gate,
num_blocks=num_blocks,
......@@ -104,8 +104,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
def test_serialize_deserialize(self):
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
inner_dim=16,
inner_activation="relu",
dropout=0.1,
use_gate=False,
num_blocks=4,
......
......@@ -76,7 +76,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_dropout_rate)
dropout_rate = kwargs.pop("output_dropout", dropout_rate)
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("inner_activation", inner_activation)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment