Commit 78015822 authored by Poorva Potdar's avatar Poorva Potdar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 360985819
parent 7ccaaf17
...@@ -145,7 +145,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value): ...@@ -145,7 +145,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
"""Implementation for sampling stratgies (go/decoding-tf-nlp).""" """Implementation for sampling strategies (go/decoding-tf-nlp)."""
def __init__(self, def __init__(self,
symbols_to_logits_fn, symbols_to_logits_fn,
...@@ -166,8 +166,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -166,8 +166,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.padded_decode = padded_decode self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
self.vocab_size = tf.convert_to_tensor(vocab_size, dtype=tf.int32) self.vocab_size = tf.convert_to_tensor(vocab_size, dtype=tf.int32)
self.max_decode_length = tf.convert_to_tensor(max_decode_length, self.max_decode_length = max_decode_length
dtype=tf.int32)
self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32) self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32) self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(sample_temperature, self.sample_temperature = tf.convert_to_tensor(sample_temperature,
...@@ -250,7 +249,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -250,7 +249,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
if inner_value.dtype != self.dtype: if inner_value.dtype != self.dtype:
raise TypeError( raise TypeError(
"initial_cache element for key '%s' has dtype %s that does not " "initial_cache element for key '%s' has dtype %s that does not "
"match SequenceBeamSearch's dtype of %s. Value: %s" % "match sampling_module's dtype of %s. Value: %s" %
(key, value.dtype.name, self.dtype.name, inner_value)) (key, value.dtype.name, self.dtype.name, inner_value))
# Current loop index (starts at 0) # Current loop index (starts at 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment