Commit a7e60974 authored by Xin Wang, committed by A. Unique TensorFlower

Added block diagonal feedforward layer.

This layer replaces the weight matrix of the output_dense layer with a block diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can optionally be added to improve the layer's expressiveness.

PiperOrigin-RevId: 418828099
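A minimal usage sketch (illustrative only, not part of this commit), mirroring the shapes used in the unit tests: with num_blocks=2 the output projection kernel shrinks from a single 128x128 matrix to two 64x64 blocks, a factor-of-num_blocks reduction in its parameters and FLOPs. The layers.BlockDiagFeedforward export relies on the __init__.py change below.

import tensorflow as tf
from official.nlp.modeling import layers

block_ffn = layers.BlockDiagFeedforward(
    intermediate_size=128,
    intermediate_activation="relu",
    dropout=0.1,
    num_blocks=2,       # output projection kernel becomes 2 diagonal blocks
    apply_mixing=True)  # optional (num_blocks x num_blocks) linear mixing

inputs = tf.keras.Input(shape=(64, 128))  # [batch, seq_len, hidden]
outputs = block_ffn(inputs)               # same shape as the input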
parent 60f6d6cb
@@ -20,6 +20,7 @@ They can be used to assemble new `tf.keras` layers or models.
from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
from official.nlp.modeling.layers.bigbird_attention import BigBirdMasks
from official.nlp.modeling.layers.block_diag_feedforward import BlockDiagFeedforward
from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based gated feedforward layer."""
# pylint: disable=g-classes-have-attributes
from typing import Optional
import tensorflow as tf
class BlockDiagFeedforward(tf.keras.layers.Layer):
"""Block diagonal feedforward layer.
This layer replaces the weight matrix of the output_dense layer with a block
diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can
be added optionally to improve layer expressibility.
Args:
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout: Dropout probability for the output dropout.
num_blocks: The number of blocks for the block diagonal matrix of the
output_dense layer.
apply_mixing: Apply linear mixing if True.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
"""
def __init__(
self,
intermediate_size: int,
intermediate_activation: str,
dropout: float,
num_blocks: int = 1,
apply_mixing: bool = True,
kernel_initializer: str = "glorot_uniform",
bias_initializer: str = "zeros",
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
kernel_constraint: Optional[tf.keras.constraints.Constraint] = None,
bias_constraint: Optional[tf.keras.constraints.Constraint] = None,
**kwargs): # pylint: disable=g-doc-args
super(BlockDiagFeedforward, self).__init__(**kwargs)
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._dropout = dropout
self._num_blocks = num_blocks
self._apply_mixing = apply_mixing
if intermediate_size % num_blocks != 0:
raise ValueError("Intermediate_size (%d) isn't a multiple of num_blocks "
"(%d)." % (intermediate_size, num_blocks))
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, input_shape):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
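    # "abc,cde->abde": project [batch, seq, hidden] into num_blocks groups of
    # size intermediate_size // num_blocks, one group per diagonal block.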
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cde->abde",
output_shape=(None, self._num_blocks,
self._intermediate_size // self._num_blocks),
bias_axes="de",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
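    # "abde,deo->abdo": each group is projected back to hidden_size //
    # num_blocks independently, which is equivalent to multiplying by a block
    # diagonal [intermediate_size, hidden_size] output kernel.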
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abde,deo->abdo",
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
bias_axes="do",
name="output",
**common_kwargs)
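    # Optional mixing: "abdo,de->abeo" combines the block outputs with a small
    # (num_blocks x num_blocks) kernel to restore cross-block interaction.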
if self._apply_mixing:
self._output_mixing = tf.keras.layers.experimental.EinsumDense(
"abdo,de->abeo",
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
name="output_mixing",
**common_kwargs)
self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout)
def get_config(self):
config = {
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout":
self._dropout,
"num_blocks":
self._num_blocks,
"apply_mixing":
self._apply_mixing,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(BlockDiagFeedforward, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
intermediate_output = self._intermediate_dense(inputs)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
if self._apply_mixing:
layer_output = self._output_mixing(layer_output)
layer_output = self._output_reshape(layer_output)
layer_output = self._output_dropout(layer_output)
return layer_output
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Keras-based gated feedforward layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import block_diag_feedforward
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BlockDiagFeedforwardTest(keras_parameterized.TestCase):
def tearDown(self):
super(BlockDiagFeedforwardTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.parameters(
(1, True, "float32"),
(1, True, "mixed_float16"),
(1, False, "float32"),
(1, False, "mixed_float16"),
(2, True, "float32"),
(2, True, "mixed_float16"),
(2, False, "float32"),
(2, False, "mixed_float16"),
)
def test_layer_creation(self, num_blocks, apply_mixing, dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=128,
intermediate_activation="relu",
dropout=0.1,
num_blocks=num_blocks,
apply_mixing=apply_mixing,
kernel_initializer="glorot_uniform",
bias_initializer="zeros")
test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)
sequence_length = 64
width = 128
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
    # The output of the feedforward layer should have the same shape as the input.
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
@parameterized.parameters(
(1, True, "float32"),
(1, True, "mixed_float16"),
(1, False, "float32"),
(1, False, "mixed_float16"),
(2, True, "float32"),
(2, True, "mixed_float16"),
(2, False, "float32"),
(2, False, "mixed_float16"),
)
def test_layer_invocation(self, num_blocks, apply_mixing, dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
dropout=0.1,
num_blocks=num_blocks,
apply_mixing=apply_mixing,
kernel_initializer="glorot_uniform",
bias_initializer="zeros")
test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)
sequence_length = 16
width = 32
# Create a 3-dimensional input (the first dimension is implicit).
data_tensor = tf.keras.Input(shape=(sequence_length, width))
output_tensor = test_layer(data_tensor)
# Create a model from the test layer.
model = tf.keras.Model(data_tensor, output_tensor)
# Invoke the model on test data.
batch_size = 6
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, width))
output_data = model.predict(input_data)
self.assertEqual(output_data.shape, (batch_size, sequence_length, width))
def test_get_config(self):
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
dropout=0.1,
num_blocks=2,
apply_mixing=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros")
test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)
new_layer = block_diag_feedforward.BlockDiagFeedforward.from_config(
test_layer.get_config())
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
if __name__ == "__main__":
tf.test.main()