Commit 7e4488ae authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 456624049
parent a6d78dd4
......@@ -1250,7 +1250,10 @@ class Decoder(Module):
training: Whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
output of a transformer encoder including
1. logits: Logits for each word in the vocab.
2. raw_logits: Logits along the moded dimension.
3. cache: Used for decoding in inference mode.
"""
cfg = self.config
# Casts inputs to the dtype.
......@@ -1298,7 +1301,7 @@ class Decoder(Module):
logits = logits / math.sqrt(cfg.d_model)
else:
logits = self.logits_dense(output)
return logits, cache
return dict(logits=logits, cache=cache, raw_logits=output)
class T5Transformer(Module):
......@@ -1392,7 +1395,7 @@ class T5Transformer(Module):
cache=None,
max_decode_len=None,
decode=False,
training=False):
training=False) -> Dict[str, tf.Tensor]:
eligible_inputs_array = []
if encoder_input_tokens is not None:
eligible_inputs = tf.cast(
......@@ -1449,7 +1452,7 @@ class T5Transformer(Module):
decoder_mask = (1.0 - tf.cast(decoder_mask, self.compute_dtype)) * -1e9
encoder_decoder_mask = (
1.0 - tf.cast(encoder_decoder_mask, self.compute_dtype)) * -1e9
logits, cache = self.decoder(
outputs = self.decoder(
decoder_input_tokens,
encoded,
decode_position=decode_position,
......@@ -1459,7 +1462,8 @@ class T5Transformer(Module):
max_decode_len=max_decode_len,
decode=decode,
training=training)
return dict(logits=logits, encoded=encoded, cache=cache)
outputs["encoded"] = encoded
return outputs
@tf.Module.with_name_scope
def __call__(self,
......
......@@ -403,7 +403,9 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
batch_size = 4
targets = tf.zeros((4, 8), dtype=tf.int32)
encoded = tf.zeros((4, 8, config.d_model), dtype=tf.float32)
logits, cache = decoder(targets, encoded)
outputs = decoder(targets, encoded)
logits = outputs["logits"]
cache = outputs["cache"]
self.assertEqual(logits.shape, (4, 8, config.vocab_size))
cache = {}
......@@ -412,13 +414,15 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
cache[1] = _create_cache(batch_size, max_decode_len, config.num_heads,
config.d_kv)
targets = tf.zeros((4, 1), dtype=tf.int32)
logits, cache = decoder(
outputs = decoder(
targets,
encoded,
decode_position=2,
cache=cache,
decode=True,
max_decode_len=max_decode_len)
logits = outputs["logits"]
cache = outputs["cache"]
self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
for entry in cache.values():
for tensor in entry.values():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment