Commit ef800b03 authored by xinliupitt

no predict func

parent 0490e860
@@ -172,7 +172,6 @@ class Seq2SeqTransformer(tf.keras.Model):
         attention_mask=attention_mask)
     if targets is None:
-      # return self.predict(encoder_outputs, attention_bias, training)
       encoder_decoder_attention_bias = attention_bias
       encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
       if self.params["padded_decode"]:
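
With the commented-out `self.predict(...)` call gone, the `targets is None` branch of `call` carries the decoding path directly: supplying both "inputs" and "targets" scores the targets with teacher forcing, while omitting "targets" triggers autoregressive decoding. The snippet below is a minimal, self-contained sketch of that dispatch pattern only; `ToySeq2Seq` and every shape in it are illustrative assumptions rather than the real `Seq2SeqTransformer`, and a greedy per-position argmax stands in for the beam-search decode.

# Toy stand-in for the dispatch shown in the hunk above (not the real model):
# no "targets" entry in the feature dict means call() runs the decode branch
# and returns {"outputs": ..., "scores": ...}.
import tensorflow as tf


class ToySeq2Seq(tf.keras.Model):
  """Illustrative model; names and shapes are assumptions, not from this diff."""

  def __init__(self, vocab_size=32, hidden_size=8):
    super().__init__()
    self.embed = tf.keras.layers.Embedding(vocab_size, hidden_size)
    self.proj = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, training=False):
    sources = inputs["inputs"]
    targets = inputs.get("targets")
    encoder_outputs = self.embed(sources)          # [batch, src_len, hidden]
    if targets is None:
      # Inference branch: a greedy per-position argmax stands in for the
      # beam-search decode that the real model runs here.
      logits = self.proj(encoder_outputs)          # [batch, src_len, vocab]
      return {"outputs": tf.argmax(logits, axis=-1),
              "scores": tf.reduce_mean(tf.reduce_max(logits, axis=-1), axis=-1)}
    # Training branch: score the provided targets (teacher forcing).
    target_len = tf.shape(targets)[1]
    return self.proj(self.embed(targets) + encoder_outputs[:, :target_len])


model = ToySeq2Seq()
src = tf.constant([[3, 5, 7, 2]], dtype=tf.int32)
print(model({"inputs": src})["outputs"])                      # decode path
print(model({"inputs": src, "targets": src}, training=True))  # scoring path
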
@@ -369,70 +368,6 @@ class Seq2SeqTransformer(tf.keras.Model):
     return symbols_to_logits_fn
-
-  def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
-    """Return predicted sequence."""
-    encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
-    if self.params["padded_decode"]:
-      batch_size = encoder_outputs.shape.as_list()[0]
-      input_length = encoder_outputs.shape.as_list()[1]
-    else:
-      batch_size = tf.shape(encoder_outputs)[0]
-      input_length = tf.shape(encoder_outputs)[1]
-    max_decode_length = input_length + self.params["extra_decode_length"]
-    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                             self.params["dtype"])
-
-    symbols_to_logits_fn = self._get_symbols_to_logits_fn(
-        max_decode_length, training)
-
-    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
-    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
-
-    # Create cache storing decoder attention values for each layer.
-    # pylint: disable=g-complex-comprehension
-    init_decode_length = (
-        max_decode_length if self.params["padded_decode"] else 0)
-    num_heads = self.params["num_heads"]
-    dim_per_head = self.params["hidden_size"] // num_heads
-    cache = {
-        str(layer): {
-            "key":
-                tf.zeros([
-                    batch_size, init_decode_length, num_heads, dim_per_head
-                ],
-                         dtype=self.params["dtype"]),
-            "value":
-                tf.zeros([
-                    batch_size, init_decode_length, num_heads, dim_per_head
-                ],
-                         dtype=self.params["dtype"])
-        } for layer in range(self.params["num_hidden_layers"])
-    }
-    # pylint: enable=g-complex-comprehension
-
-    # Add encoder output and attention bias to the cache.
-    cache["encoder_outputs"] = encoder_outputs
-    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
-
-    # Use beam search to find the top beam_size sequences and scores.
-    decoded_ids, scores = beam_search.sequence_beam_search(
-        symbols_to_logits_fn=symbols_to_logits_fn,
-        initial_ids=initial_ids,
-        initial_cache=cache,
-        vocab_size=self.params["vocab_size"],
-        beam_size=self.params["beam_size"],
-        alpha=self.params["alpha"],
-        max_decode_length=max_decode_length,
-        eos_id=EOS_ID,
-        padded_decode=self.params["padded_decode"],
-        dtype=self.params["dtype"])
-
-    # Get the top sequence for each batch element
-    top_decoded_ids = decoded_ids[:, 0, 1:]
-    top_scores = scores[:, 0]
-
-    return {"outputs": top_decoded_ids, "scores": top_scores}
 class TransformerEncoder(tf.keras.layers.Layer):
   """Transformer decoder stack.

   Like the encoder stack, the decoder stack is made up of N identical layers.
......
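
For reference, the removed `predict` delegated the actual search to `beam_search.sequence_beam_search`, feeding it the callable built by `_get_symbols_to_logits_fn` together with a cache holding pre-allocated per-layer "key"/"value" tensors of shape [batch_size, init_decode_length, num_heads, dim_per_head], the encoder outputs, and the attention bias. The sketch below illustrates the step-function contract such a search consumes, as I understand it: a callable taking (ids, index, cache) and returning per-step vocabulary logits plus the updated cache. Everything in it (toy_symbols_to_logits_fn, greedy_decode, the random logits) is a hypothetical stand-in, and a plain greedy loop replaces the beam search so the example stays self-contained.

# Hedged sketch of the symbols_to_logits_fn interface used by the removed
# predict(): fn(ids, index, cache) -> (logits, updated cache). A greedy loop
# replaces beam_search.sequence_beam_search so this runs on its own.
import tensorflow as tf

VOCAB_SIZE = 16


def toy_symbols_to_logits_fn(ids, index, cache):
  # ids: [batch, index + 1] tokens decoded so far. A real model would attend
  # over cache["encoder_outputs"] and update the per-layer key/value tensors;
  # here deterministic random logits stand in for the decoder.
  batch_size = ids.shape[0]
  logits = tf.random.stateless_normal([batch_size, VOCAB_SIZE], seed=[index, 0])
  return logits, cache


def greedy_decode(batch_size, max_decode_length):
  ids = tf.zeros([batch_size, 1], dtype=tf.int32)   # analogue of initial_ids
  cache = {}                                        # per-layer K/V would live here
  for i in range(max_decode_length):
    logits, cache = toy_symbols_to_logits_fn(ids, i, cache)
    next_id = tf.argmax(logits, axis=-1, output_type=tf.int32)[:, None]
    ids = tf.concat([ids, next_id], axis=1)
  return ids[:, 1:]  # drop the initial token, as predict() did with [:, 0, 1:]


print(greedy_decode(batch_size=2, max_decode_length=5))

One detail worth noting from the removed code: when padded_decode is set, init_decode_length equals max_decode_length, so the key/value cache is allocated at its full final size up front and every tensor shape stays static during the decode loop, which is what padded (e.g. TPU-friendly) decoding relies on.
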