Commit bc01957e authored by Bruce Fontaine's avatar Bruce Fontaine Committed by A. Unique TensorFlower
Browse files

Switch from python slice to tf.slice as python slice emits a

tf.strided_slice which is much more complicated than a tf.slice.

PiperOrigin-RevId: 359792738
parent d4dd827f
...@@ -514,7 +514,11 @@ class SequenceBeamSearch(tf.Module): ...@@ -514,7 +514,11 @@ class SequenceBeamSearch(tf.Module):
max_length_norm = _length_normalization( max_length_norm = _length_normalization(
self.alpha, self.max_decode_length, dtype=self.dtype) self.alpha, self.max_decode_length, dtype=self.dtype)
# Get the best possible scores from alive sequences. # Get the best possible scores from alive sequences.
best_alive_scores = alive_log_probs[:, 0] / max_length_norm # This tf.slice/tf.squeeze is equivalent to alive_log_probs[:, 0] which
# emits a tf.strided_slice. tf.slice is easier to reason about as we aren't
# actually taking a non trivial stride.
best_alive_scores = tf.squeeze(tf.slice(alive_log_probs, [0, 0], [-1, 1]),
axis=1) / max_length_norm
# Compute worst score in finished sequences for each batch element # Compute worst score in finished sequences for each batch element
finished_scores *= tf.cast(finished_flags, finished_scores *= tf.cast(finished_flags,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment