Commit 10ee28dd, authored by Jiayu Ye, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 431722553
parent aa870ff4
...@@ -226,6 +226,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -226,6 +226,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
funnel encoder relies on. funnel encoder relies on.
share_rezero: bool. Whether to share ReZero alpha between the attention share_rezero: bool. Whether to share ReZero alpha between the attention
layer and the ffn layer. This option is specific to ReZero. layer and the ffn layer. This option is specific to ReZero.
with_dense_inputs: Whether to accept dense embeddings as the input.
""" """
def __init__( def __init__(
...@@ -413,6 +414,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -413,6 +414,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
logging.warning('List inputs to %s are discouraged.', self.__class__) logging.warning('List inputs to %s are discouraged.', self.__class__)
if len(inputs) == 3: if len(inputs) == 3:
word_ids, mask, type_ids = inputs word_ids, mask, type_ids = inputs
dense_inputs = None
dense_mask = None
dense_type_ids = None
elif len(inputs) == 6:
word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids = inputs
else: else:
raise ValueError('Unexpected inputs to %s with length at %d.' % raise ValueError('Unexpected inputs to %s with length at %d.' %
(self.__class__, len(inputs))) (self.__class__, len(inputs)))
...@@ -420,10 +426,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer): ...@@ -420,10 +426,21 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
word_ids = inputs.get('input_word_ids') word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask') mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids') type_ids = inputs.get('input_type_ids')
dense_inputs = inputs.get('dense_inputs', None)
dense_mask = inputs.get('dense_mask', None)
dense_type_ids = inputs.get('dense_type_ids', None)
else: else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__) raise ValueError('Unexpected inputs type to %s.' % self.__class__)
word_embeddings = self._embedding_layer(word_ids) word_embeddings = self._embedding_layer(word_ids)
if dense_inputs is not None:
# Concat the dense embeddings at sequence begin so unpool_len can control
# embedding not being pooled.
word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
mask = tf.concat([dense_mask, mask], axis=1)
# absolute position embeddings # absolute position embeddings
position_embeddings = self._position_embedding_layer(word_embeddings) position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids) type_embeddings = self._type_embedding_layer(type_ids)
......
...@@ -101,6 +101,55 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -101,6 +101,55 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(tf.float32, data.dtype) self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(pooled_dtype, pooled.dtype) self.assertAllEqual(pooled_dtype, pooled.dtype)
def test_network_creation_dense(self):
  """Checks output shapes when dense inputs accompany the token inputs.

  The encoder is called with a six-element list
  (word ids, mask, type ids, dense inputs, dense mask, dense type ids) under
  the "mixed_float16" global policy, and the static shapes of
  "sequence_output" and "pooled_output" are verified.
  """
  # NOTE(review): the global policy is changed here and not restored inside
  # this method; presumably a tearDown elsewhere resets it -- confirm.
  tf.keras.mixed_precision.set_global_policy("mixed_float16")

  hidden_size = 32
  num_layers = 3
  token_length = 21
  dense_length = 3

  # A small FunnelTransformerEncoder; max_sequence_length has to cover the
  # token positions plus the dense positions (21 + 3 = 24).
  encoder = funnel_transformer.FunnelTransformerEncoder(
      vocab_size=100,
      hidden_size=hidden_size,
      num_attention_heads=2,
      num_layers=num_layers,
      pool_stride=2,
      pool_type="avg",
      max_sequence_length=token_length + dense_length,
      unpool_length=0,
      transformer_cls="TransformerEncoderBlock")

  def int_input(length):
    # Symbolic int32 input of the given length (batch dimension implicit).
    return tf.keras.Input(shape=(length,), dtype=tf.int32)

  # Inputs are created in the same order in which they are passed below.
  word_ids = int_input(token_length)
  mask = int_input(token_length)
  type_ids = int_input(token_length)
  dense_inputs = tf.keras.Input(
      shape=(dense_length, hidden_size), dtype=tf.float32)
  dense_mask = int_input(dense_length)
  dense_type_ids = int_input(dense_length)

  outputs = encoder(
      [word_ids, mask, type_ids, dense_inputs, dense_mask, dense_type_ids])
  sequence_output = outputs["sequence_output"]
  pooled_output = outputs["pooled_output"]

  self.assertIsInstance(encoder.transformer_layers, list)
  self.assertLen(encoder.transformer_layers, num_layers)
  self.assertIsInstance(encoder.pooler_layer, tf.keras.layers.Dense)

  # Stride=2 compresses sequence length to half the size at each layer.
  # For pool_type = max or avg, this configuration gives each layer of
  # seq length: 24->12->6->3.
  self.assertAllEqual([None, 3, hidden_size], sequence_output.shape.as_list())
  self.assertAllEqual([None, hidden_size], pooled_output.shape.as_list())
def test_invalid_stride_and_num_layers(self): def test_invalid_stride_and_num_layers(self):
hidden_size = 32 hidden_size = 32
num_layers = 3 num_layers = 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment