"test/vscode:/vscode.git/clone" did not exist on "85d2365d337ca81eb353645bca15a199cc348847"
Commit 92c0288a authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377944468
parent c7644458
......@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def _continue_search(self, state) -> tf.Tensor:
i = state[decoding_module.StateKeys.CUR_INDEX]
return tf.less(i, self.max_decode_length)
# Have we reached max decoding length?
not_at_end = tf.less(i, self.max_decode_length)
# Have all sampled sequences reached an EOS?
all_has_eos = tf.reduce_all(
state[decoding_module.StateKeys.FINISHED_FLAGS],
axis=None,
name="search_finish_cond")
return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))
def _finished_flags(self, topk_ids, state) -> tf.Tensor:
new_finished_flags = tf.equal(topk_ids, self.eos_id)
new_finished_flags = tf.logical_or(
new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
return new_finished_flags
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment