Commit 49c66406 authored by Poorva Potdar's avatar Poorva Potdar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341433786
parent dcc224c6
...@@ -88,12 +88,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -88,12 +88,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
if not self.do_sample: if not self.do_sample:
topk_log_probs, topk_ids = self._greedy(probs) topk_log_probs, topk_ids = self._greedy(probs)
else: else:
temperature_fn = SamplingModule.sample_logits_with_temperature temperature_fn = self.sample_logits_with_temperature
probs = tf.cond(self.sample_temperature > 0.0, probs = tf.cond(self.sample_temperature > 0.0,
lambda: temperature_fn(probs, self.sample_temperature), lambda: temperature_fn(probs, self.sample_temperature),
lambda: probs) lambda: probs)
probs = tf.cond(self.top_k is not None and self.top_k > 1, probs = tf.cond(self.top_k is not None and self.top_k > 1,
lambda: SamplingModule._sample_top_k(probs, self.top_k), lambda: self._sample_top_k(probs, self.top_k),
lambda: probs) lambda: probs)
topk_ids = tf.random.categorical(probs, dtype=tf.int32, num_samples=1) topk_ids = tf.random.categorical(probs, dtype=tf.int32, num_samples=1)
topk_log_probs = tf.gather( topk_log_probs = tf.gather(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment