Commit 7e4488ae authored by Jiayu Ye, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 456624049
parent a6d78dd4
@@ -1250,7 +1250,10 @@ class Decoder(Module):
       training: Whether it is training pass, affecting dropouts.
 
     Returns:
-      output of a transformer encoder.
+      output of a transformer encoder including
+        1. logits: Logits for each word in the vocab.
+        2. raw_logits: Logits along the model dimension.
+        3. cache: Used for decoding in inference mode.
     """
     cfg = self.config
     # Casts inputs to the dtype.
@@ -1298,7 +1301,7 @@ class Decoder(Module):
       logits = logits / math.sqrt(cfg.d_model)
     else:
       logits = self.logits_dense(output)
-    return logits, cache
+    return dict(logits=logits, cache=cache, raw_logits=output)
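
The hunk above replaces the Decoder's (logits, cache) tuple with a named dictionary and additionally exposes the pre-projection activations as raw_logits. The toy module below is not the real Decoder; it is a minimal sketch that only mirrors the new return contract so the caller-side migration is visible. DemoDecoder, its Dense projection, and the shapes are illustrative assumptions.

import tensorflow as tf


class DemoDecoder(tf.Module):
  """Toy stand-in that mimics the dict-valued return contract above."""

  def __init__(self, vocab_size=16, name=None):
    super().__init__(name=name)
    # Illustrative output projection, not the real logits_dense layer.
    self.logits_dense = tf.keras.layers.Dense(vocab_size)

  def __call__(self, activations, cache=None):
    output = activations                # [batch, seq, d_model]
    logits = self.logits_dense(output)  # [batch, seq, vocab_size]
    # Named outputs instead of positional tuple members.
    return dict(logits=logits, cache=cache or {}, raw_logits=output)


demo = DemoDecoder()
activations = tf.zeros([4, 8, 8])

# Before this change callers unpacked a tuple:  logits, cache = decoder(...)
# After it they index the returned dict by name:
outputs = demo(activations)
logits = outputs["logits"]          # [4, 8, 16]
raw_logits = outputs["raw_logits"]  # [4, 8, 8], along the model dimension
cache = outputs["cache"]

Returning a dict keeps existing call sites stable when further outputs (raw_logits here) are added later, which is also what lets T5Transformer below forward the decoder outputs unchanged.
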
class T5Transformer(Module):
@@ -1392,7 +1395,7 @@ class T5Transformer(Module):
              cache=None,
              max_decode_len=None,
              decode=False,
-             training=False):
+             training=False) -> Dict[str, tf.Tensor]:
     eligible_inputs_array = []
     if encoder_input_tokens is not None:
       eligible_inputs = tf.cast(
@@ -1449,7 +1452,7 @@ class T5Transformer(Module):
     decoder_mask = (1.0 - tf.cast(decoder_mask, self.compute_dtype)) * -1e9
     encoder_decoder_mask = (
         1.0 - tf.cast(encoder_decoder_mask, self.compute_dtype)) * -1e9
-    logits, cache = self.decoder(
+    outputs = self.decoder(
         decoder_input_tokens,
         encoded,
         decode_position=decode_position,
@@ -1459,7 +1462,8 @@ class T5Transformer(Module):
         max_decode_len=max_decode_len,
         decode=decode,
         training=training)
-    return dict(logits=logits, encoded=encoded, cache=cache)
+    outputs["encoded"] = encoded
+    return outputs
 
   @tf.Module.with_name_scope
   def __call__(self,
......
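
At the model level the change is a pass-through: T5Transformer no longer repacks individual tensors but takes the decoder's dict and adds the encoder output under "encoded". A minimal, self-contained sketch of that pattern, using a stand-in decoder callable instead of the real module (all names here are illustrative assumptions, not part of the diff):

import tensorflow as tf


def wrap_with_encoded(decoder_fn, encoded):
  # Any callable returning the new-style dict can be wrapped this way; the
  # wrapper augments the dict instead of rebuilding it key by key.
  outputs = decoder_fn(encoded)
  outputs["encoded"] = encoded
  return outputs


# Stand-in decoder returning the same keys as Decoder after this change.
fake_decoder = lambda x: dict(logits=tf.zeros([4, 8, 16]), cache={}, raw_logits=x)
outputs = wrap_with_encoded(fake_decoder, tf.zeros([4, 8, 8]))
assert sorted(outputs) == ["cache", "encoded", "logits", "raw_logits"]
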
@@ -403,7 +403,9 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
     batch_size = 4
     targets = tf.zeros((4, 8), dtype=tf.int32)
     encoded = tf.zeros((4, 8, config.d_model), dtype=tf.float32)
-    logits, cache = decoder(targets, encoded)
+    outputs = decoder(targets, encoded)
+    logits = outputs["logits"]
+    cache = outputs["cache"]
     self.assertEqual(logits.shape, (4, 8, config.vocab_size))
 
     cache = {}
@@ -412,13 +414,15 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
     cache[1] = _create_cache(batch_size, max_decode_len, config.num_heads,
                              config.d_kv)
     targets = tf.zeros((4, 1), dtype=tf.int32)
-    logits, cache = decoder(
+    outputs = decoder(
         targets,
         encoded,
         decode_position=2,
         cache=cache,
         decode=True,
         max_decode_len=max_decode_len)
+    logits = outputs["logits"]
+    cache = outputs["cache"]
     self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
     for entry in cache.values():
       for tensor in entry.values():
......
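
The test above seeds the cache with a _create_cache helper that is defined elsewhere in the test file and not shown in this hunk. Below is a plausible sketch of such a helper, assuming each per-layer entry holds zero-initialized "key" and "value" tensors sized for the full decode length; the key names and the [batch, length, heads, head_size] axis order are assumptions, not taken from the diff.

import tensorflow as tf


def create_empty_cache(batch_size, max_decode_len, num_heads, head_size):
  # Hypothetical stand-in for the test's _create_cache helper.
  shape = [batch_size, max_decode_len, num_heads, head_size]
  return {
      "key": tf.zeros(shape, dtype=tf.float32),
      "value": tf.zeros(shape, dtype=tf.float32),
  }


# Mirrors the test setup: one entry per decoder layer, keyed by layer index,
# which is why the test walks cache.values() and then entry.values().
cache = {
    layer: create_empty_cache(batch_size=4, max_decode_len=8,
                              num_heads=2, head_size=4)
    for layer in range(2)
}
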