Commit 2eb655c4 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 475932957
parent 370ccbdc
@@ -135,7 +135,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     # Deprecation warning.
     if output_range is not None:
-      logging.warning("`output_range` is avaliable as an argument for `call()`."
+      logging.warning("`output_range` is available as an argument for `call()`."
                       "The `output_range` as __init__ argument is deprecated.")
     self._num_heads = num_attention_heads
...
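
For reference, the shape of this migration: `output_range` (which limits the self-attention queries to the first positions of the sequence) is now supplied per call rather than fixed at construction. A minimal sketch of the new usage; the tensor shapes are illustrative assumptions, not taken from the diff:

import tensorflow as tf
from official.nlp.modeling import layers

block = layers.TransformerEncoderBlock(
    num_attention_heads=10,
    inner_dim=2048,
    inner_activation='relu')

# Hypothetical shapes: hidden size 80 divides evenly by the 10 heads.
data = tf.ones([2, 8, 80])   # [batch, seq_len, hidden]
mask = tf.ones([2, 8, 8])    # [batch, seq_len, seq_len] attention mask
out = block([data, mask], output_range=1)
# out has shape [2, 1, 80]; the tests below compare such an output against
# output_tensor[:, 0:1, :] from a full-range forward pass.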
@@ -116,11 +116,10 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
     new_layer = transformer_cls(
         num_attention_heads=10,
         inner_dim=2048,
-        inner_activation='relu',
-        output_range=1)
-    _ = new_layer([input_data, mask_data])
+        inner_activation='relu')
+    _ = new_layer([input_data, mask_data], output_range=1)
     new_layer.set_weights(test_layer.get_weights())
-    new_output_tensor = new_layer([input_data, mask_data])
+    new_output_tensor = new_layer([input_data, mask_data], output_range=1)
     self.assertAllClose(
         new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
@@ -147,11 +146,10 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
         num_attention_heads=10,
         inner_dim=2048,
         inner_activation='relu',
-        output_range=1,
         norm_first=True)
-    _ = new_layer(input_data)
+    _ = new_layer(input_data, output_range=1)
     new_layer.set_weights(test_layer.get_weights())
-    new_output_tensor = new_layer(input_data)
+    new_output_tensor = new_layer(input_data, output_range=1)
     self.assertAllClose(
         new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
@@ -177,11 +175,10 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
         num_attention_heads=10,
         inner_dim=2048,
         inner_activation='relu',
-        output_range=1,
         norm_first=True)
-    _ = new_layer([input_data, mask_data])
+    _ = new_layer([input_data, mask_data], output_range=1)
     new_layer.set_weights(test_layer.get_weights())
-    new_output_tensor = new_layer([input_data, mask_data])
+    new_output_tensor = new_layer([input_data, mask_data], output_range=1)
     self.assertAllClose(
         new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
@@ -291,7 +288,6 @@ class TransformerEncoderBlockLayerTestWithoutParams(keras_parameterized.TestCase
         num_attention_heads=2,
         inner_dim=128,
         inner_activation='relu',
-        output_range=output_range,
         norm_first=True)
     # Forward path.
     q_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
@@ -299,7 +295,7 @@ class TransformerEncoderBlockLayerTestWithoutParams(keras_parameterized.TestCase
     dummy_mask = tf.zeros([2, 4, 8], dtype=tf.float32)
     inputs = [q_tensor, kv_tensor, dummy_mask]
     with self.assertRaises(tf.errors.InvalidArgumentError):
-      test_layer(inputs)
+      test_layer(inputs, output_range=output_range)
     test_layer = TransformerEncoderBlock(
         num_attention_heads=2,
...
@@ -256,6 +256,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       share_rezero: bool = False,
       **kwargs):
     super().__init__(**kwargs)
+    if output_range is not None:
+      logging.warning('`output_range` is available as an argument for `call()`.'
+                      'The `output_range` as __init__ argument is deprecated.')
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -306,6 +311,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     # Will raise an error if the string is not supported.
     if isinstance(transformer_cls, str):
       transformer_cls = _str2transformer_cls[transformer_cls]
+    self._num_layers = num_layers
     for i in range(num_layers):
       layer = transformer_cls(
           num_attention_heads=num_attention_heads,
@@ -316,7 +322,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
-          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=tf_utils.clone_initializer(initializer),
          share_rezero=share_rezero,
          name='transformer/layer_%d' % i)
@@ -407,7 +412,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
       input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
-  def call(self, inputs):
+  def call(self, inputs, output_range: Optional[tf.Tensor] = None):
     # inputs are [word_ids, mask, type_ids]
     if isinstance(inputs, (list, tuple)):
       logging.warning('List inputs to %s are discouraged.', self.__class__)
@@ -477,7 +482,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
                 x[:, :self._unpool_length, :],
                 dtype=pooled_inputs.dtype), pooled_inputs),
             axis=1)
-        x = layer([query_inputs, x, attention_mask])
+        x = layer([query_inputs, x, attention_mask],
+                  output_range=output_range if i == self._num_layers -
+                  1 else None)
       # Pools the corresponding attention_mask.
       if i < len(self._transformer_layers) - 1:
         attention_mask = _pool_and_concat(
@@ -496,9 +503,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
           pooling_transforms)
       for i, layer in enumerate(self._transformer_layers):
         attention_mask = attention_masks[i]
+        transformer_output_range = None
+        if i == self._num_layers - 1:
+          transformer_output_range = output_range
        # Bypass no pooling cases.
        if self._pool_strides[i] == 1:
-          x = layer([x, x, attention_mask])
+          x = layer([x, x, attention_mask],
+                    output_range=transformer_output_range)
        else:
          pooled_inputs = tf.einsum(
              'BFD,FT->BTD',
@@ -510,7 +521,8 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
                  x[:, :self._unpool_length, :],
                  dtype=pooled_inputs.dtype), pooled_inputs),
              axis=1)
-          x = layer([query_inputs, x, attention_mask])
+          x = layer([query_inputs, x, attention_mask],
+                    output_range=transformer_output_range)
        encoder_outputs.append(x)
     last_encoder_output = encoder_outputs[-1]
...
@@ -229,14 +229,14 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         pool_stride=pool_stride,
         unpool_length=unpool_length)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    dict_outputs = test_network([word_ids, mask, type_ids])
+    dict_outputs = test_network([word_ids, mask, type_ids],
+                                output_range=output_range)
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
...
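
The funnel encoder test above exercises the new call-time argument end to end. A standalone sketch under assumed sizes (the vocab_size, type_vocab_size, and pool_stride values here are hypothetical; the constructor arguments otherwise mirror the test):

import tensorflow as tf
from official.nlp.modeling import networks

encoder = networks.FunnelTransformerEncoder(
    vocab_size=100,        # hypothetical vocabulary size
    num_attention_heads=2,
    num_layers=3,
    type_vocab_size=16,
    pool_stride=2)

word_ids = tf.zeros([2, 8], dtype=tf.int32)
mask = tf.ones([2, 8], dtype=tf.int32)
type_ids = tf.zeros([2, 8], dtype=tf.int32)

# As in the diff, output_range is forwarded only to the final layer.
outputs = encoder([word_ids, mask, type_ids], output_range=1)
sequence_output = outputs['sequence_output']  # [2, 1, hidden_size]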
@@ -129,6 +129,10 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
       attention_dropout = kwargs.pop('attention_dropout_rate')
     super().__init__(**kwargs)
+    if output_range is not None:
+      logging.warning('`output_range` is available as an argument for `call()`.'
+                      'The `output_range` as __init__ argument is deprecated.')
     activation = tf.keras.activations.get(inner_activation)
     initializer = tf.keras.initializers.get(initializer)
@@ -204,7 +208,6 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
-          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
@@ -254,7 +257,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
       input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
       input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
-  def call(self, inputs):
+  def call(self, inputs, output_range: Optional[tf.Tensor] = None):
     if isinstance(inputs, dict):
       word_ids = inputs.get('input_word_ids')
       mask = inputs.get('input_mask')
@@ -303,8 +306,11 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
     # 4. Finally, all tokens go through the last layer.
     # Step 1.
-    for layer in self._transformer_layers[:self._num_layers // 2 - 1]:
-      x = layer([x, attention_mask])
+    for i, layer in enumerate(self._transformer_layers[:self._num_layers // 2 -
+                                                       1]):
+      x = layer([x, attention_mask],
+                output_range=output_range if i == self._num_layers -
+                1 else None)
       encoder_outputs.append(x)
     # Step 2.
@@ -322,12 +328,17 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
     # Then, call transformer layer with cross attention.
     x_selected = self._transformer_layers[self._num_layers // 2 - 1](
-        [x_selected, x_all, attention_mask_token_pass])
+        [x_selected, x_all, attention_mask_token_pass],
+        output_range=output_range if self._num_layers // 2 -
+        1 == self._num_layers - 1 else None)
     encoder_outputs.append(x_selected)
     # Step 3.
-    for layer in self._transformer_layers[self._num_layers // 2:-1]:
-      x_selected = layer([x_selected, attention_mask_token_drop])
+    for i, layer in enumerate(self._transformer_layers[self._num_layers //
+                                                       2:-1]):
+      x_selected = layer([x_selected, attention_mask_token_drop],
+                         output_range=output_range if i == self._num_layers - 1
+                         else None)
       encoder_outputs.append(x_selected)
     # Step 4.
@@ -339,7 +350,8 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
       x = tf.gather(x, reverse_indices, batch_dims=1, axis=1)
     # Then, call transformer layer with all tokens.
-    x = self._transformer_layers[-1]([x, attention_mask])
+    x = self._transformer_layers[-1]([x, attention_mask],
+                                     output_range=output_range)
     encoder_outputs.append(x)
     last_encoder_output = encoder_outputs[-1]
...
@@ -150,7 +150,6 @@ class TokenDropBertEncoderTest(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         dict_outputs=True,
         token_keep_k=2,
         token_allow_list=(),
@@ -160,7 +159,8 @@ class TokenDropBertEncoderTest(keras_parameterized.TestCase):
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     dict_outputs = test_network(
-        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids),
+        output_range=output_range)
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
@@ -349,7 +349,6 @@ class TokenDropBertEncoderTest(keras_parameterized.TestCase):
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
-        output_range=output_range,
         token_keep_k=2,
         token_allow_list=(),
         token_deny_list=())
@@ -358,7 +357,8 @@ class TokenDropBertEncoderTest(keras_parameterized.TestCase):
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     dict_outputs = test_network(
-        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids),
+        output_range=output_range)
     data = dict_outputs["sequence_output"]
     pooled = dict_outputs["pooled_output"]
...
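
The same pattern applies to the token-dropping encoder. A sketch with hypothetical sizes; the import path is an assumption about where TokenDropBertEncoder lives in the Model Garden:

import tensorflow as tf
# Assumed module path for the token-dropping project.
from official.projects.token_dropping import encoder as token_drop

encoder_net = token_drop.TokenDropBertEncoder(
    vocab_size=100,        # hypothetical vocabulary size
    num_attention_heads=2,
    num_layers=3,
    type_vocab_size=16,
    token_keep_k=2,        # keep the top-scoring 2 tokens in the drop phase
    token_allow_list=(),
    token_deny_list=())

word_ids = tf.zeros([2, 8], dtype=tf.int32)
mask = tf.ones([2, 8], dtype=tf.int32)
type_ids = tf.zeros([2, 8], dtype=tf.int32)

# output_range only takes effect in the final all-token layer (Step 4).
outputs = encoder_net(
    dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids),
    output_range=1)
sequence_output = outputs['sequence_output']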