Commit 86df41f7 authored by Hongkun Yu, committed by A. Unique TensorFlower

#KerasNLP Update TransformerEncoderBlock to support Q, KV as two input streams.

PiperOrigin-RevId: 350170448
parent ece64b24
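For orientation before the diff: after this change, TransformerEncoderBlock's call() accepts a single tensor, a [tensor, mask] pair, or a [query, key/value, mask] triple. A minimal usage sketch, assuming the layer is importable from this repo's official/nlp/keras_nlp package (the import path and shapes below are illustrative, not part of the commit):

import tensorflow as tf
from official.nlp.keras_nlp import layers  # assumed import path

block = layers.TransformerEncoderBlock(
    num_attention_heads=2, inner_dim=128, inner_activation="relu")

x = tf.zeros([2, 4, 16])    # [batch, seq, width] query/input stream
kv = tf.zeros([2, 8, 16])   # key/value stream with a different length
mask = tf.zeros([2, 4, 8])  # [batch, q_length, kv_length]

y = block(x)                         # self-attention, no mask
y = block([x, tf.zeros([2, 4, 4])])  # self-attention with a [batch, seq, seq] mask
y = block([x, kv, mask])             # new: separate Q and KV input streams
# In all three cases y has the query's shape: [2, 4, 16].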
The copyright year is bumped from 2020 to 2021 in the headers of the touched TensorFlow files:

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -85,7 +85,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         kernel.
       **kwargs: keyword arguments.
     """
-    super(TransformerEncoderBlock, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._num_heads = num_attention_heads
     self._inner_dim = inner_dim
@@ -111,23 +111,18 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._attention_initializer = self._kernel_initializer

   def build(self, input_shape):
-    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
-    input_tensor_shape = tf.TensorShape(input_tensor)
+    if isinstance(input_shape, tf.TensorShape):
+      input_tensor_shape = input_shape
+    elif isinstance(input_shape, (list, tuple)):
+      input_tensor_shape = tf.TensorShape(input_shape[0])
+    else:
+      raise ValueError(
+          "The type of input shape argument is not supported, got: %s" %
+          type(input_shape))
     if len(input_tensor_shape.as_list()) != 3:
       raise ValueError("TransformerEncoderBlock expects a three-dimensional "
                        "input of shape [batch, sequence, width].")
-    batch_size, sequence_length, hidden_size = input_tensor_shape
+    hidden_size = input_tensor_shape[-1]
-    if len(input_shape) == 2:
-      mask_tensor_shape = tf.TensorShape(input_shape[1])
-      expected_mask_tensor_shape = tf.TensorShape(
-          [batch_size, sequence_length, sequence_length])
-      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
-        raise ValueError("When passing a mask tensor to "
-                         "TransformerEncoderBlock, the mask tensor must be of "
-                         "shape [batch, sequence_length, sequence_length] "
-                         "(here %s). Got a mask tensor of shape %s." %
-                         (expected_mask_tensor_shape, mask_tensor_shape))
     if hidden_size % self._num_heads != 0:
       raise ValueError(
           "The input size (%d) is not a multiple of the number of attention "
@@ -234,15 +229,38 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     return dict(list(base_config.items()) + list(config.items()))

   def call(self, inputs):
-    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
-      input_tensor, attention_mask = inputs
+    """Transformer self-attention encoder block call.
+
+    Args:
+      inputs: a single tensor or a list of tensors.
+        `input tensor` as the single sequence of embeddings.
+        [`input tensor`, `attention mask`] to have the additional attention
+          mask.
+        [`query tensor`, `key value tensor`, `attention mask`] to have separate
+          input streams for the query, and key/value to the multi-head
+          attention.
+
+    Returns:
+      An output tensor with the same dimensions as input/query tensor.
+    """
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError("Unexpected inputs to %s with length at %d" %
+                         (self.__class__, len(inputs)))
     else:
-      input_tensor, attention_mask = (inputs, None)
+      input_tensor, key_value, attention_mask = (inputs, None, None)

     if self._output_range:
       if self._norm_first:
         source_tensor = input_tensor[:, 0:self._output_range, :]
         input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
       target_tensor = input_tensor[:, 0:self._output_range, :]
       if attention_mask is not None:
         attention_mask = attention_mask[:, 0:self._output_range, :]
@@ -250,10 +268,14 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       if self._norm_first:
         source_tensor = input_tensor
         input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
       target_tensor = input_tensor
+    if key_value is None:
+      key_value = input_tensor
     attention_output = self._attention_layer(
-        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
......
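The key wiring change is in the last hunk above: the attention layer now consumes key_value, falling back to the input tensor for plain self-attention, so the output length follows the query stream. That matches the contract of tf.keras.layers.MultiHeadAttention and can be checked in isolation (a standalone sketch, not code from this commit):

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
query = tf.zeros([2, 4, 16])
key_value = tf.zeros([2, 8, 16])
mask = tf.ones([2, 4, 8])  # [batch, q_length, kv_length]

out = mha(query=query, value=key_value, attention_mask=mask)
print(out.shape)  # (2, 4, 16): length and width follow the query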
@@ -55,18 +55,6 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
-    test_layer = transformer_cls(
-        num_attention_heads=10, inner_dim=2048, inner_activation='relu')
-    sequence_length = 21
-    width = 80
-    # Create a 3-dimensional input (the first dimension is implicit).
-    data_tensor = tf.keras.Input(shape=(sequence_length, width))
-    # Create a 2-dimensional input (the first dimension is implicit).
-    mask_tensor = tf.keras.Input(shape=(sequence_length, sequence_length - 3))
-    with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
-      _ = test_layer([data_tensor, mask_tensor])
-
   def test_layer_invocation(self, transformer_cls):
     test_layer = transformer_cls(
         num_attention_heads=10, inner_dim=2048, inner_activation='relu')
@@ -249,6 +237,20 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual([1, input_length, width], output_data.shape)

+  def test_separate_qkv(self, transformer_cls):
+    test_layer = transformer_cls(
+        num_attention_heads=2,
+        inner_dim=128,
+        inner_activation='relu',
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
+    # Forward path.
+    q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    kv_tensor = tf.zeros([2, 8, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
+    inputs = [q_tensor, kv_tensor, dummy_mask]
+    output = test_layer(inputs)
+    self.assertEqual(output.shape, q_tensor.shape)
+

 @keras_parameterized.run_all_keras_modes
 class TransformerArgumentTest(keras_parameterized.TestCase):
......
@@ -77,7 +77,7 @@ class Transformer(keras_nlp.layers.TransformerEncoderBlock):
       intermediate_dropout=0.0,
       attention_initializer=None,
       **kwargs):
-    super(Transformer, self).__init__(
+    super().__init__(
         num_attention_heads=num_attention_heads,
         inner_dim=intermediate_size,
         inner_activation=intermediate_activation,
@@ -105,7 +105,7 @@ class CompiledTransformer(Transformer):
   @tf_function_if_eager(experimental_compile=True)
   def call(self, inputs):
-    return super(CompiledTransformer, self).call(inputs)
+    return super().call(inputs)

 @tf.keras.utils.register_keras_serializable(package="Text")
......
The same 2020-to-2021 copyright bump is applied across all of the touched Orbit files:

-# Copyright 2020 The Orbit Authors. All Rights Reserved.
+# Copyright 2021 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......