Commit f45d70c9, authored by Yuexin Wu, committed by A. Unique TensorFlower
Browse files

#FunnelTransformer support AvgPool1D option.

PiperOrigin-RevId: 402882989
parent 24b8c1d1
......@@ -90,6 +90,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
pool_type: Pooling type. Choose from ['max', 'avg'].
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
......@@ -123,6 +124,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
output_dropout=0.1,
attention_dropout=0.1,
pool_type='max',
pool_stride=2,
unpool_length=0,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
......@@ -204,9 +206,16 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
if len(pool_stride) != num_layers:
raise ValueError('Lengths of pool_stride and num_layers are not equal.')
pool_strides = pool_stride
# TODO(crickwu): explore tf.keras.layers.serialize method.
if pool_type == 'max':
pool_cls = tf.keras.layers.MaxPooling1D
elif pool_type == 'avg':
pool_cls = tf.keras.layers.AveragePooling1D
else:
raise ValueError('pool_type not supported.')
self._att_input_pool_layers = []
for layer_pool_stride in pool_strides:
att_input_pool_layer = tf.keras.layers.MaxPooling1D(
att_input_pool_layer = pool_cls(
pool_size=layer_pool_stride,
strides=layer_pool_stride,
padding='same',
......@@ -232,6 +241,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'pool_type': pool_type,
'pool_stride': pool_stride,
'unpool_length': unpool_length,
}
......
......@@ -37,9 +37,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
super(FunnelTransformerEncoderTest, self).tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters(("mix", "mixed_float16", tf.float16),
("float32", "float32", tf.float32))
def test_network_creation(self, policy, pooled_dtype):
@parameterized.named_parameters(
("mix_max", "mixed_float16", tf.float16, "max"),
("float32_max", "float32", tf.float32, "max"),
("mix_avg", "mixed_float16", tf.float16, "avg"),
("float32_avg", "float32", tf.float32, "avg"))
def test_network_creation(self, policy, pooled_dtype, pool_type):
tf.keras.mixed_precision.set_global_policy(policy)
hidden_size = 32
......@@ -53,6 +56,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
pool_type=pool_type,
unpool_length=0)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -238,6 +242,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
embedding_width=16,
embedding_layer=None,
norm_first=False,
pool_type="max",
pool_stride=2,
unpool_length=0)
network = funnel_transformer.FunnelTransformerEncoder(**kwargs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment