"tools/vscode:/vscode.git/clone" did not exist on "f02cf7ca64b8d5def8438345a4be090bc84d5d1f"
Commit bc01957e authored by Bruce Fontaine's avatar Bruce Fontaine Committed by A. Unique TensorFlower
Browse files

Switch from python slice to tf.slice as python slice emits a

tf.strided_slice which is much more complicated than a tf.slice.

PiperOrigin-RevId: 359792738
parent d4dd827f
......@@ -514,7 +514,11 @@ class SequenceBeamSearch(tf.Module):
max_length_norm = _length_normalization(
self.alpha, self.max_decode_length, dtype=self.dtype)
# Get the best possible scores from alive sequences.
best_alive_scores = alive_log_probs[:, 0] / max_length_norm
# This tf.slice/tf.squeeze is equivalent to alive_log_probs[:, 0] which
# emits a tf.strided_slice. tf.slice is easier to reason about as we aren't
# actually taking a non trivial stride.
best_alive_scores = tf.squeeze(tf.slice(alive_log_probs, [0, 0], [-1, 1]),
axis=1) / max_length_norm
# Compute worst score in finished sequences for each batch element
finished_scores *= tf.cast(finished_flags,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment