Commit f07221a7 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 427509556
parent 1cc40617
......@@ -1086,12 +1086,17 @@ class Encoder(Module):
self.output_dropout = Dropout(self.config.dropout_rate,)
@tf.Module.with_name_scope
def __call__(self, inputs, encoder_mask=None, training=False):
def __call__(self,
inputs,
encoder_mask=None,
dense_inputs=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
inputs: input data
encoder_mask: the encoder self-attention mask.
dense_inputs: dense input data, concat after the embedding.
training: whether it is training pass, affecting dropouts.
Returns:
......@@ -1102,11 +1107,20 @@ class Encoder(Module):
encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
cfg = self.config
x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
if dense_inputs is not None:
x = tf.concat([x, dense_inputs], axis=1)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
input_length = tf_utils.get_shape_list(inputs)[1]
position_bias = self.relative_embedding(input_length, input_length)
if dense_inputs is not None:
# Here we ignore relative position bias for dense embeddings.
dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
# Position bias shape: [batch, 1, len, len]
paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
[0, dense_input_length]])
position_bias = tf.pad(position_bias, paddings, "CONSTANT")
for i in range(cfg.num_layers):
x = self.encoder_layers[i](
......@@ -1308,31 +1322,56 @@ class T5Transformer(Module):
def encode(self,
encoder_input_tokens,
encoder_segment_ids=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
training=False):
eligible_positions = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
if encoder_dense_inputs is not None:
eligible_dense_position = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_positions = tf.concat(
[eligible_positions, eligible_dense_position], axis=1)
encoder_mask = make_attention_mask(
eligible_positions, eligible_positions, dtype=tf.bool)
if encoder_segment_ids is not None:
if encoder_dense_segment_ids is not None:
encoder_segment_ids = tf.concat(
[encoder_segment_ids, encoder_dense_segment_ids], axis=1)
segment_mask = make_attention_mask(
encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
encoder_mask = (1.0 - tf.cast(encoder_mask, self.compute_dtype)) * -1e9
return self.encoder(encoder_input_tokens, encoder_mask, training=training)
return self.encoder(
encoder_input_tokens,
encoder_mask,
encoder_dense_inputs,
training=training)
def decode(
self,
encoded,
decoder_target_tokens,
encoder_input_tokens, # only used for masks
encoder_dense_inputs=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
encoder_dense_segment_ids=None,
decoder_segment_ids=None,
decode_position=None,
cache=None,
max_decode_len=None,
decode=False,
training=False):
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
if encoder_dense_inputs is not None:
eligible_dense_inputs = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_inputs = tf.concat([eligible_inputs, eligible_dense_inputs],
axis=1)
if decode:
# For decoding, the decoder_input_tokens is the decoder_target_tokens.
decoder_input_tokens = decoder_target_tokens
......@@ -1342,14 +1381,12 @@ class T5Transformer(Module):
tf.cast(
tf.not_equal(tf.ones_like(decoder_target_tokens), 0),
self.compute_dtype),
tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype),
eligible_inputs,
dtype=tf.bool)
else:
# Note that, masks should be created using decoder_target_tokens.
eligible_targets = tf.cast(
tf.not_equal(decoder_target_tokens, 0), self.compute_dtype)
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
decoder_mask = tf.math.logical_and(
make_attention_mask(
eligible_targets, eligible_targets, dtype=tf.bool),
......@@ -1365,6 +1402,9 @@ class T5Transformer(Module):
decoder_segment_ids,
tf.equal,
dtype=tf.bool))
if encoder_dense_segment_ids is not None:
encoder_segment_ids = tf.concat(
[encoder_segment_ids, encoder_dense_segment_ids], axis=1)
encoder_decoder_mask = tf.math.logical_and(
encoder_decoder_mask,
make_attention_mask(
......@@ -1392,6 +1432,8 @@ class T5Transformer(Module):
def __call__(self,
encoder_input_tokens,
decoder_target_tokens,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
......@@ -1401,9 +1443,12 @@ class T5Transformer(Module):
Args:
encoder_input_tokens: input tokens to the encoder.
decoder_target_tokens: target tokens to the decoder.
encoder_dense_inputs: input dense vectors to the encoder.
encoder_dense_segment_ids: dense input segmentation info for packed
decoder_input_tokens: input tokens to the decoder, only required for
training.
encoder_segment_ids: input segmentation info for packed examples.
examples.
decoder_segment_ids: target segmentation info for packed examples.
training: whether it is training pass, affecting dropouts.
......@@ -1413,13 +1458,17 @@ class T5Transformer(Module):
encoded = self.encode(
encoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
encoder_dense_inputs=encoder_dense_inputs,
encoder_dense_segment_ids=encoder_dense_segment_ids,
training=training)
outputs = self.decode(
encoded=encoded,
decoder_target_tokens=decoder_target_tokens,
encoder_input_tokens=encoder_input_tokens, # only used for masks.
encoder_dense_inputs=encoder_dense_inputs, # only used for masks.
decoder_input_tokens=decoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
encoder_dense_segment_ids=encoder_dense_segment_ids,
decoder_segment_ids=decoder_segment_ids,
training=training)
outputs["encoded"] = encoded
......
......@@ -354,6 +354,24 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
encoded = encoder(tf.zeros((4, 8), dtype=tf.int32))
self.assertEqual(encoded.shape, (4, 8, config.d_model))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_encoder_with_dense(self, dtype):
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=dtype)
encoded = encoder(
tf.zeros((4, 8), dtype=tf.int32),
dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
self.assertEqual(encoded.shape, (4, 10, config.d_model))
def test_decoder(self):
max_decode_len = 10
config = t5.T5TransformerParams(
......@@ -445,6 +463,58 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 26, False, tf.float32),)
def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
expect_num_variables, layer_sharing, dtype):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
layer_sharing=layer_sharing,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
outputs = transformer(
encoder_input_tokens=inputs,
encoder_dense_inputs=dense_inputs,
decoder_input_tokens=inputs,
decoder_target_tokens=inputs,
encoder_segment_ids=segments,
encoder_dense_segment_ids=dense_segments,
decoder_segment_ids=segments)
cache = {}
batch_size = 2
cache[0] = _create_cache(
batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
outputs = transformer.decode(
encoder_input_tokens=inputs,
encoder_dense_inputs=dense_inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 39, tf.float32, 2),
("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment