Commit f45d70c9 authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

#FunnelTransformer: support AvgPool1D option.

PiperOrigin-RevId: 402882989
parent 24b8c1d1
...@@ -90,6 +90,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -90,6 +90,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout. dropout.
attention_dropout: The dropout rate to use for the attention layers within attention_dropout: The dropout rate to use for the attention layers within
the transformer layers. the transformer layers.
pool_type: Pooling type. Choose from ['max', 'avg'].
pool_stride: An int or a list of ints. Pooling stride(s) to compress the pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size. sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers. If set to list, the number of elements needs to match num_layers.
...@@ -123,6 +124,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -123,6 +124,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True), inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1, output_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
pool_type='max',
pool_stride=2, pool_stride=2,
unpool_length=0, unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
...@@ -204,9 +206,16 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -204,9 +206,16 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
if len(pool_stride) != num_layers: if len(pool_stride) != num_layers:
raise ValueError('Lengths of pool_stride and num_layers are not equal.') raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride pool_strides = pool_stride
# TODO(crickwu): explore tf.keras.layers.serialize method.
if pool_type == 'max':
pool_cls = tf.keras.layers.MaxPooling1D
elif pool_type == 'avg':
pool_cls = tf.keras.layers.AveragePooling1D
else:
raise ValueError('pool_type not supported.')
self._att_input_pool_layers = [] self._att_input_pool_layers = []
for layer_pool_stride in pool_strides: for layer_pool_stride in pool_strides:
att_input_pool_layer = tf.keras.layers.MaxPooling1D( att_input_pool_layer = pool_cls(
pool_size=layer_pool_stride, pool_size=layer_pool_stride,
strides=layer_pool_stride, strides=layer_pool_stride,
padding='same', padding='same',
...@@ -232,6 +241,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -232,6 +241,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'embedding_width': embedding_width, 'embedding_width': embedding_width,
'embedding_layer': embedding_layer, 'embedding_layer': embedding_layer,
'norm_first': norm_first, 'norm_first': norm_first,
'pool_type': pool_type,
'pool_stride': pool_stride, 'pool_stride': pool_stride,
'unpool_length': unpool_length, 'unpool_length': unpool_length,
} }
......
...@@ -37,9 +37,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -37,9 +37,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
super(FunnelTransformerEncoderTest, self).tearDown() super(FunnelTransformerEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32") tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters(("mix", "mixed_float16", tf.float16), @parameterized.named_parameters(
("float32", "float32", tf.float32)) ("mix_max", "mixed_float16", tf.float16, "max"),
def test_network_creation(self, policy, pooled_dtype): ("float32_max", "float32", tf.float32, "max"),
("mix_avg", "mixed_float16", tf.float16, "avg"),
("float32_avg", "float32", tf.float32, "avg"))
def test_network_creation(self, policy, pooled_dtype, pool_type):
tf.keras.mixed_precision.set_global_policy(policy) tf.keras.mixed_precision.set_global_policy(policy)
hidden_size = 32 hidden_size = 32
...@@ -53,6 +56,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -53,6 +56,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
num_attention_heads=2, num_attention_heads=2,
num_layers=num_layers, num_layers=num_layers,
pool_stride=pool_stride, pool_stride=pool_stride,
pool_type=pool_type,
unpool_length=0) unpool_length=0)
# Create the inputs (note that the first dimension is implicit). # Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32) word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
...@@ -238,6 +242,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -238,6 +242,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
embedding_width=16, embedding_width=16,
embedding_layer=None, embedding_layer=None,
norm_first=False, norm_first=False,
pool_type="max",
pool_stride=2, pool_stride=2,
unpool_length=0) unpool_length=0)
network = funnel_transformer.FunnelTransformerEncoder(**kwargs) network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment