Commit fd424b07 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 432997436
parent 4c571a3c
......@@ -1087,16 +1087,17 @@ class Encoder(Module):
@tf.Module.with_name_scope
def __call__(self,
inputs,
inputs=None,
encoder_mask=None,
dense_inputs=None,
training=False):
"""Applies Transformer model on the inputs.
Args:
inputs: input data
inputs: input word ids. Optional if dense data are provided.
encoder_mask: the encoder self-attention mask.
dense_inputs: dense input data, concat after the embedding.
dense_inputs: dense input data. Concat after the embedding if word ids
are provided.
training: whether it is training pass, affecting dropouts.
Returns:
......@@ -1106,16 +1107,27 @@ class Encoder(Module):
if encoder_mask is not None:
encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
cfg = self.config
x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
inputs_array = []
if inputs is not None:
inputs_array.append(
self.input_embed(inputs, one_hot=cfg.one_hot_embedding))
if dense_inputs is not None:
x = tf.concat([x, dense_inputs], axis=1)
inputs_array.append(dense_inputs)
if not inputs_array:
raise ValueError("At least one of inputs and dense_inputs must not be "
"None.")
x = tf.concat(inputs_array, axis=1)
tensor_shape = tf_utils.get_shape_list(x)
tensor_shape[-2] = 1
x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
if inputs is not None:
input_length = tf_utils.get_shape_list(inputs)[1]
else:
input_length = 0
position_bias = self.relative_embedding(input_length, input_length)
if dense_inputs is not None:
# Here we ignore relative position bias for dense embeddings.
# TODO(yejiayu): If we proceed to video use cases, rework this part.
dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
# Position bias shape: [batch, 1, len, len]
paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
......@@ -1320,25 +1332,35 @@ class T5Transformer(Module):
compute_dtype=self.compute_dtype)
def encode(self,
encoder_input_tokens,
encoder_input_tokens=None,
encoder_segment_ids=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
training=False):
eligible_positions = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
eligible_position_array = []
if encoder_input_tokens is not None:
eligible_position_array.append(
tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype))
if encoder_dense_inputs is not None:
eligible_dense_position = tf.cast(
eligible_dense_positions = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_positions = tf.concat(
[eligible_positions, eligible_dense_position], axis=1)
eligible_position_array.append(eligible_dense_positions)
if not eligible_position_array:
raise ValueError("At least one of encoder_input_tokens and"
" encoder_dense_inputs must be provided.")
eligible_positions = tf.concat(eligible_position_array, axis=1)
encoder_mask = make_attention_mask(
eligible_positions, eligible_positions, dtype=tf.bool)
encoder_segment_id_array = []
if encoder_segment_ids is not None:
encoder_segment_id_array.append(encoder_segment_ids)
if encoder_dense_segment_ids is not None:
encoder_segment_ids = tf.concat(
[encoder_segment_ids, encoder_dense_segment_ids], axis=1)
encoder_segment_id_array.append(encoder_dense_segment_ids)
if encoder_segment_id_array:
encoder_segment_ids = tf.concat(encoder_segment_id_array, axis=1)
segment_mask = make_attention_mask(
encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
......@@ -1353,7 +1375,7 @@ class T5Transformer(Module):
self,
encoded,
decoder_target_tokens,
encoder_input_tokens, # only used for masks
encoder_input_tokens=None, # only used for masks
encoder_dense_inputs=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
......@@ -1364,14 +1386,18 @@ class T5Transformer(Module):
max_decode_len=None,
decode=False,
training=False):
eligible_inputs_array = []
if encoder_input_tokens is not None:
eligible_inputs = tf.cast(
tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
eligible_inputs_array.append(eligible_inputs)
if encoder_dense_inputs is not None:
eligible_dense_inputs = tf.cast(
tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
self.compute_dtype)
eligible_inputs = tf.concat([eligible_inputs, eligible_dense_inputs],
axis=1)
eligible_inputs_array.append(eligible_dense_inputs)
eligible_inputs = tf.concat(eligible_inputs_array, axis=1)
if decode:
# For decoding, the decoder_input_tokens is the decoder_target_tokens.
decoder_input_tokens = decoder_target_tokens
......@@ -1430,8 +1456,8 @@ class T5Transformer(Module):
@tf.Module.with_name_scope
def __call__(self,
encoder_input_tokens,
decoder_target_tokens,
encoder_input_tokens=None,
decoder_target_tokens=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
decoder_input_tokens=None,
......@@ -1456,7 +1482,7 @@ class T5Transformer(Module):
a dictionary of logits/cache.
"""
encoded = self.encode(
encoder_input_tokens,
encoder_input_tokens=encoder_input_tokens,
encoder_segment_ids=encoder_segment_ids,
encoder_dense_inputs=encoder_dense_inputs,
encoder_dense_segment_ids=encoder_dense_segment_ids,
......
......@@ -372,6 +372,22 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
self.assertEqual(encoded.shape, (4, 10, config.d_model))
@parameterized.named_parameters(("bfloat16", tf.bfloat16),
("float32", tf.float32))
def test_encoder_only_dense(self, dtype):
config = t5.T5TransformerParams(
num_layers=2,
d_model=4,
d_kv=3,
num_heads=4,
d_ff=16,
vocab_size=10,
vocab_embeddings_initializer=tf.keras.initializers.Ones(),
relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=dtype)
encoded = encoder(dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
self.assertEqual(encoded.shape, (4, 2, config.d_model))
def test_decoder(self):
max_decode_len = 10
config = t5.T5TransformerParams(
......@@ -515,6 +531,58 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 26, False, tf.float32),)
def test_transformer_with_dense_only(self, ffn_activations,
logits_via_embedding,
expect_num_variables, layer_sharing,
dtype):
max_decode_len = 10
config = t5.T5TransformerParams(
num_layers=1,
d_model=8,
d_kv=4,
num_heads=4,
d_ff=32,
vocab_size=10,
shared_embedding=True,
layer_sharing=layer_sharing,
ffn_activations=ffn_activations,
logits_via_embedding=logits_via_embedding)
transformer = t5.T5Transformer(config, compute_dtype=dtype)
self.assertLen(transformer.trainable_variables, expect_num_variables)
decoder_inputs = tf.convert_to_tensor(
np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
decoder_segments = tf.convert_to_tensor(
np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
outputs = transformer(
encoder_dense_inputs=dense_inputs,
encoder_dense_segment_ids=dense_segments,
decoder_input_tokens=decoder_inputs,
decoder_target_tokens=decoder_inputs,
decoder_segment_ids=decoder_segments)
cache = {}
batch_size = 2
cache[0] = _create_cache(
batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
outputs = transformer.decode(
encoder_dense_inputs=dense_inputs,
encoded=outputs["encoded"],
decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
decode_position=1,
decode=True,
max_decode_len=max_decode_len,
cache=cache)
self.assertEqual(outputs["logits"].shape,
(batch_size, 1, config.vocab_size))
for v in transformer.trainable_variables:
print(v.name, v.shape)
self.assertEqual(v.dtype, tf.float32)
@parameterized.named_parameters(
("t5_10", ("relu",), True, 39, tf.float32, 2),
("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment