Commit a1fd33c5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377944468
parent 1fa648a7
...@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def _continue_search(self, state) -> tf.Tensor: def _continue_search(self, state) -> tf.Tensor:
i = state[decoding_module.StateKeys.CUR_INDEX] i = state[decoding_module.StateKeys.CUR_INDEX]
return tf.less(i, self.max_decode_length) # Have we reached max decoding length?
not_at_end = tf.less(i, self.max_decode_length)
# Have all sampled sequences reached an EOS?
all_has_eos = tf.reduce_all(
state[decoding_module.StateKeys.FINISHED_FLAGS],
axis=None,
name="search_finish_cond")
return tf.logical_and(not_at_end, tf.logical_not(all_has_eos))
def _finished_flags(self, topk_ids, state) -> tf.Tensor: def _finished_flags(self, topk_ids, state) -> tf.Tensor:
new_finished_flags = tf.equal(topk_ids, self.eos_id) new_finished_flags = tf.equal(topk_ids, self.eos_id)
new_finished_flags = tf.logical_or( new_finished_flags = tf.logical_or(
new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS]) new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
return new_finished_flags return new_finished_flags
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment