Commit 6af14369 authored by Yuexin Wu, committed by A. Unique TensorFlower

Allow ReZero to take 3 inputs that can be used for compressing sequence length.

PiperOrigin-RevId: 413261350
parent d86d4aa6
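
For reference, a minimal usage sketch of the new three-input call path (query, key/value, attention mask). The import path and the variable names below are illustrative assumptions and not part of this diff; the behavior shown (output keeps the shorter query length) matches the test added in this commit.

    import tensorflow as tf
    from official.nlp.modeling.layers import rezero_transformer

    layer = rezero_transformer.ReZeroTransformer(
        num_attention_heads=2,
        intermediate_size=128,
        intermediate_activation='relu')

    # A short "compressed" query attends over a longer key/value sequence.
    query = tf.zeros([2, 4, 16])      # [batch, compressed_length, width]
    key_value = tf.zeros([2, 8, 16])  # [batch, full_length, width]
    mask = tf.ones([2, 4, 8])         # [batch, compressed_length, full_length]

    output = layer([query, key_value, mask])  # output.shape == [2, 4, 16]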
@@ -80,8 +80,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._use_layer_norm = use_layer_norm
 
   def build(self, input_shape):
-    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
-    input_tensor_shape = tf.TensorShape(input_tensor)
+    if isinstance(input_shape, tf.TensorShape):
+      input_tensor_shape = input_shape
+    elif isinstance(input_shape, (list, tuple)):
+      input_tensor_shape = tf.TensorShape(input_shape[0])
+    else:
+      raise ValueError(
+          "The type of input shape argument is not supported, got: %s" %
+          type(input_shape))
     if len(input_tensor_shape.as_list()) != 3:
       raise ValueError("TransformerLayer expects a three-dimensional input of "
                        "shape [batch, sequence, width].")
@@ -198,19 +205,30 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     self._rezero_a.assign(0.)
 
   def call(self, inputs):
-    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
-      input_tensor, attention_mask = inputs
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError("Unexpected inputs to %s with length at %d" %
+                         (self.__class__, len(inputs)))
     else:
-      input_tensor, attention_mask = (inputs, None)
+      input_tensor, key_value, attention_mask = (inputs, None, None)
 
     if self._output_range:
       target_tensor = input_tensor[:, 0:self._output_range, :]
-      attention_mask = attention_mask[:, 0:self._output_range, :]
+      if attention_mask is not None:
+        attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
 
+    if key_value is None:
+      key_value = input_tensor
+
     attention_output = self._attention_layer(
-        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
@@ -124,6 +124,20 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
     new_output_tensor = new_layer([input_data, mask_data])
     self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
 
+  def test_separate_qkv(self):
+    test_layer = rezero_transformer.ReZeroTransformer(
+        num_attention_heads=2,
+        intermediate_size=128,
+        intermediate_activation='relu',
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+    # Forward path.
+    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    kv_tensor = tf.zeros([2, 8, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
+    inputs = [q_tensor, kv_tensor, dummy_mask]
+    output = test_layer(inputs)
+    self.assertEqual(output.shape, q_tensor.shape)
+
 
 if __name__ == '__main__':
   tf.test.main()