Commit 10ee28dd authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 431722553
parent aa870ff4
......@@ -226,6 +226,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
funnel encoder relies on.
share_rezero: bool. Whether to share ReZero alpha between the attention
layer and the ffn layer. This option is specific to ReZero.
with_dense_inputs: Whether to accept dense embeddings as the input.
"""
def __init__(
......@@ -413,6 +414,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
logging.warning('List inputs to %s are discouraged.', self.__class__)
if len(inputs) == 3:
word_ids, mask, type_ids = inputs
dense_inputs = None
dense_mask = None
dense_type_ids = None
elif len(inputs) == 6:
word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids = inputs
else:
raise ValueError('Unexpected inputs to %s with length at %d.' %
(self.__class__, len(inputs)))
......@@ -420,10 +426,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
dense_inputs = inputs.get('dense_inputs', None)
dense_mask = inputs.get('dense_mask', None)
dense_type_ids = inputs.get('dense_type_ids', None)
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
word_embeddings = self._embedding_layer(word_ids)
if dense_inputs is not None:
# Concat the dense embeddings at sequence begin so unpool_len can control
# embedding not being pooled.
word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
mask = tf.concat([dense_mask, mask], axis=1)
# absolute position embeddings
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
......
......@@ -101,6 +101,55 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(pooled_dtype, pooled.dtype)
def test_network_creation_dense(self):
tf.keras.mixed_precision.set_global_policy("mixed_float16")
pool_type = "avg"
hidden_size = 32
sequence_length = 21
dense_sequence_length = 3
pool_stride = 2
num_layers = 3
# Create a small FunnelTransformerEncoder for testing.
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
pool_type=pool_type,
max_sequence_length=sequence_length + dense_sequence_length,
unpool_length=0,
transformer_cls="TransformerEncoderBlock")
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
dense_inputs = tf.keras.Input(
shape=(dense_sequence_length, hidden_size), dtype=tf.float32)
dense_mask = tf.keras.Input(shape=(dense_sequence_length,), dtype=tf.int32)
dense_type_ids = tf.keras.Input(
shape=(dense_sequence_length,), dtype=tf.int32)
dict_outputs = test_network(
[word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids])
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, num_layers)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
# Stride=2 compresses sequence length to half the size at each layer.
# For pool_type = max or avg,
# this configuration gives each layer of seq length: 24->12->6->3.
expected_data_shape = [None, 3, hidden_size]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
def test_invalid_stride_and_num_layers(self):
hidden_size = 32
num_layers = 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment