Commit 5460577d authored by Poorva Potdar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 344068616
parent 0f7580bd
......@@ -58,6 +58,58 @@ class StateKeys:
FINISHED_FLAGS = "FINISHED_FLAGS"
def log_prob_from_logits(logits):
return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
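# Illustrative sketch, assuming TF 2.x eager execution (not part of the
# original commit): subtracting the log-sum-exp normalizes logits into log
# probabilities, so exponentiating them recovers the softmax.
example_logits = tf.constant([[2.0, 1.0, 0.0]])
example_log_probs = log_prob_from_logits(example_logits)
# tf.exp(example_log_probs) matches tf.nn.softmax(example_logits), sums to 1.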
def shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape = tensor.get_shape().as_list()
# Ensure that the shape values are not None
dynamic_shape = tf.shape(tensor)
for i in range(len(shape)): # pylint: disable=consider-using-enumerate
if shape[i] is None:
shape[i] = dynamic_shape[i]
return shape
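# Illustrative sketch, assuming TF 2.x (not part of the original commit):
# eagerly all dimensions are known and plain ints are returned; inside a
# tf.function, an unknown batch dimension comes back as a scalar tensor.
assert shape_list(tf.ones([7, 3])) == [7, 3]
@tf.function(input_signature=[tf.TensorSpec([None, 3])])
def _reshape_example(t):
  batch, width = shape_list(t)  # batch is a scalar tf.Tensor, width is 3
  return tf.reshape(t, [batch, width])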
def get_shape_keep_last_dim(tensor):
shape_list_obj = shape_list(tensor)
for i in range(len(shape_list_obj) - 1):
shape_list_obj[i] = None
if isinstance(shape_list_obj[-1], tf.Tensor):
shape_list_obj[-1] = None
return tf.TensorShape(shape_list_obj)
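# Illustrative sketch, assuming TF 2.x (not part of the original commit):
# every leading dimension is relaxed to None while the last dimension stays
# static; this is the loop-invariant shape used for decode caches in
# tf.while_loop.
example_cache = tf.ones([2, 4, 64])
get_shape_keep_last_dim(example_cache)  # -> TensorShape([None, None, 64])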
def expand_to_same_rank(tensor, target):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as
target.
Raises:
ValueError: if the rank of tensor or target is None.
"""
if tensor.shape.rank is None:
raise ValueError("Expect rank for tensor shape, but got None.")
if target.shape.rank is None:
raise ValueError("Expect rank for target shape, but got None.")
with tf.name_scope("expand_rank"):
diff_rank = target.shape.rank - tensor.shape.rank
for _ in range(diff_rank):
tensor = tf.expand_dims(tensor, -1)
return tensor
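# Illustrative sketch (not part of the original commit): a per-batch
# condition of shape [2] gains trailing singleton dimensions so it
# broadcasts against a [2, 4, 5] target.
example_cond = tf.constant([1.0, 0.0])                            # shape [2]
example_target = tf.ones([2, 4, 5])
example_cond = expand_to_same_rank(example_cond, example_target)  # [2, 1, 1]
example_masked = example_target * example_cond  # broadcasts to [2, 4, 5]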
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
"""A base class for the API required for decoding (go/decoding-tf-nlp)."""
......@@ -233,57 +285,5 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
else:
raise AssertionError("Invalid dtype: %s" % self.dtype)
@staticmethod
def _log_prob_from_logits(logits):
return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
@staticmethod
def _shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape = tensor.get_shape().as_list()
# Ensure that the shape values are not None
dynamic_shape = tf.shape(tensor)
for i in range(len(shape)): # pylint: disable=consider-using-enumerate
if shape[i] is None:
shape[i] = dynamic_shape[i]
return shape
@staticmethod
def _get_shape_keep_last_dim(tensor):
shape_list_obj = DecodingModule._shape_list(tensor)
for i in range(len(shape_list_obj) - 1):
shape_list_obj[i] = None
if isinstance(shape_list_obj[-1], tf.Tensor):
shape_list_obj[-1] = None
return tf.TensorShape(shape_list_obj)
@staticmethod
def _expand_to_same_rank(tensor, target):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as
target.
Raises:
ValueError: if the rank of tensor or target is None.
"""
if tensor.shape.rank is None:
raise ValueError("Expect rank for tensor shape, but got None.")
if target.shape.rank is None:
raise ValueError("Expect rank for target shape, but got None.")
with tf.name_scope("expand_rank"):
diff_rank = target.shape.rank - tensor.shape.rank
for _ in range(diff_rank):
tensor = tf.expand_dims(tensor, -1)
return tensor
......@@ -62,12 +62,12 @@ class DecodingModuleTest(tf.test.TestCase):
def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0)
x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
shape = decoding_module.DecodingModule._get_shape_keep_last_dim(x)
shape = decoding_module.get_shape_keep_last_dim(x)
self.assertAllEqual([None, None, None, 5], shape.as_list())
def test_shape_list(self):
x = tf.ones([7, 1])
shape = decoding_module.DecodingModule._shape_list(x)
shape = decoding_module.shape_list(x)
self.assertAllEqual([7, 1], shape)
def test_inf(self):
......
......@@ -23,6 +23,127 @@ import tensorflow as tf
from official.nlp.modeling.ops import decoding_module
def greedy(log_probs):
"""Returns the top ids and scores based on greedy decoding."""
log_probs, ids = tf.nn.top_k(log_probs, k=1)
return log_probs, ids
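# Illustrative sketch (not part of the original commit): greedy decoding is
# a top-1 (argmax) over the vocabulary axis.
example_log_probs = tf.math.log(tf.constant([[0.2, 0.5, 0.3]]))
scores, ids = greedy(example_log_probs)  # ids -> [[1]], scores -> [[log(0.5)]]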
def sample_logits_with_temperature(logits, temperature):
"""Applies a sampling temperature.
A temperature in [0, 1) skews the distribution towards high-probability
tokens and lowers the mass in the tail of the distribution.
Args:
logits: Input logits for next token.
temperature: Tensor for specifying the sampling temperature.
Returns:
Logits with applied temperature.
"""
return logits / temperature
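# Illustrative sketch (not part of the original commit): temperatures below
# 1.0 sharpen the softmax while temperatures above 1.0 flatten it (values
# shown are approximate).
example_logits = tf.constant([[2.0, 1.0, 0.0]])
tf.nn.softmax(sample_logits_with_temperature(example_logits, 0.5))
# -> ~[[0.87, 0.12, 0.02]], sharper than tf.nn.softmax(example_logits)
tf.nn.softmax(sample_logits_with_temperature(example_logits, 2.0))
# -> ~[[0.51, 0.31, 0.19]], flatter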
def sample_top_k(logits, top_k):
"""Chooses top_k logits and sets the others to negative infinity.
Args:
logits: Input logits for next token.
top_k: Tensor to specify the top_k values.
Returns:
Logits with top_k filtering applied.
"""
top_k_logits = tf.math.top_k(logits, k=top_k)
indices_to_remove = logits < top_k_logits[0][..., -1, None]
top_k_logits = set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
return top_k_logits
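# Illustrative sketch (not part of the original commit): with k=2, only the
# two largest logits survive; the rest are masked to -inf and receive zero
# probability under softmax.
example_logits = tf.constant([[5.0, 4.0, 3.0, 2.0]])
sample_top_k(example_logits, tf.constant(2))  # -> [[5.0, 4.0, -inf, -inf]]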
def sample_top_p(logits, top_p):
"""Chooses most probable logits with cumulative probabilities upto top_p.
Sets the remaining logits to negative infinity.
Args:
logits: Input logits for next token.
top_p: Float tensor with a value >=0 and < 1.0
Returns:
Logits with top_p filtering applied.
"""
sorted_indices = tf.argsort(logits, direction="DESCENDING")
# Flatten logits, as tf.gather on TPU needs the axis to be a compile-time
# constant.
range_for_gather = tf.expand_dims(tf.range(0, logits.shape[0]), axis=1)
range_for_gather = tf.tile(range_for_gather * logits.shape[1],
[1, logits.shape[1]]) + sorted_indices
flattened_logits = tf.reshape(logits, [-1])
flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
sorted_logits = tf.reshape(
tf.gather(flattened_logits, flattened_sorted_indices),
[logits.shape[0], logits.shape[1]])
cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
# Remove tokens with cumulative probability above the threshold.
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold.
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = tf.concat([
tf.zeros_like(sorted_indices_to_remove[:, :1]),
sorted_indices_to_remove[:, 1:]
], -1)
# Scatter sorted indices to original indexes.
indices_to_remove = scatter_values_on_batch_indices(
sorted_indices_to_remove, sorted_indices)
top_p_logits = set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
return top_p_logits
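# Illustrative sketch (not part of the original commit): for probabilities
# [0.5, 0.3, 0.1, 0.1] and top_p=0.5, cumulative probs are
# [0.5, 0.8, 0.9, 1.0]; the shift keeps the first token that crosses the
# threshold, so 0.5 and 0.3 survive and both 0.1 tails are masked to -inf.
example_probs = tf.constant([[0.5, 0.3, 0.1, 0.1]])
sample_top_p(tf.math.log(example_probs), tf.constant(0.5))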
def scatter_values_on_batch_indices(values, batch_indices):
"""Scatter `values` into a tensor using `batch_indices`.
Args:
values: tensor of shape [batch_size, vocab_size] containing the values to
scatter
batch_indices: tensor of shape [batch_size, vocab_size] containing the
indices to insert (should be a permutation of range(0, vocab_size)).
Returns:
Tensor of shape [batch_size, vocab_size] with values inserted at
batch_indices
"""
tensor_shape = decoding_module.shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape(
tf.broadcast_to(
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1),
tensor_shape), [1, -1])
pair_indices = tf.transpose(
tf.concat([broad_casted_batch_dims,
tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices,
tf.reshape(values, [-1]), tensor_shape)
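# Illustrative sketch (not part of the original commit): values[b][j] lands
# at column batch_indices[b][j], inverting the sort permutation produced in
# sample_top_p.
example_values = tf.constant([[10., 20., 30.]])
example_permutation = tf.constant([[2, 0, 1]])
scatter_values_on_batch_indices(example_values, example_permutation)
# -> [[20., 30., 10.]]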
def set_tensor_by_indices_to_value(input_tensor, indices, value):
"""Where indices is True, set the value in input_tensor to value.
Args:
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
value_tensor = tf.zeros_like(input_tensor) + value
output_tensor = tf.where(indices, value_tensor, input_tensor)
return output_tensor
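# Illustrative sketch (not part of the original commit): positions where the
# boolean mask is True are overwritten with the scalar value.
set_tensor_by_indices_to_value(
  tf.constant([[1.0, 2.0]]), tf.constant([[True, False]]), 0.0)
# -> [[0.0, 2.0]]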
class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
"""Implementation for sampling stratgies (go/decoding-tf-nlp)."""
......@@ -33,19 +154,25 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
max_decode_length: int,
eos_id: int,
padded_decode: bool,
top_k: tf.Tensor = None,
sample_temperature: tf.Tensor = None,
top_k=0,
top_p=1.0,
sample_temperature=0.0,
enable_greedy: bool = True,
dtype: tf.DType = tf.float32):
"""Initialize sampling module."""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
self.length_normalization_fn = length_normalization_fn
self.max_decode_length = max_decode_length
self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype)
self.top_k = top_k
self.sample_temperature = sample_temperature
self.vocab_size = tf.convert_to_tensor(vocab_size, dtype=tf.int32)
self.max_decode_length = tf.convert_to_tensor(max_decode_length,
dtype=tf.int32)
self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(sample_temperature,
dtype=tf.float32)
self.enable_greedy = enable_greedy
super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
......@@ -79,23 +206,29 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
ids = alive_seq
new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
candidate_log_probs = decoding_module.DecodingModule._log_prob_from_logits(
candidate_log_probs = decoding_module.log_prob_from_logits(
new_logits)
original_log_probs = candidate_log_probs + alive_log_probs
probs = original_log_probs
topk_log_probs, topk_ids = None, None
if not self.do_sample:
topk_log_probs, topk_ids = self._greedy(probs)
if self.enable_greedy:
topk_log_probs, topk_ids = greedy(original_log_probs)
else:
temperature_fn = self.sample_logits_with_temperature
probs = tf.cond(self.sample_temperature > 0.0,
lambda: temperature_fn(probs, self.sample_temperature),
lambda: probs)
probs = tf.cond(self.top_k is not None and self.top_k > 1,
lambda: self._sample_top_k(probs, self.top_k),
lambda: probs)
topk_ids = tf.random.categorical(probs, dtype=tf.int32, num_samples=1)
temperature_fn = sample_logits_with_temperature
sampled_logits = tf.cond(
self.sample_temperature > 0.0,
lambda: temperature_fn(new_logits, self.sample_temperature),
lambda: new_logits)
sampled_logits = tf.cond(
self.top_k > 0,
lambda: sample_top_k(sampled_logits, self.top_k),
lambda: sampled_logits)
sampled_logits = tf.cond(
self.top_p < 1,
lambda: sample_top_p(sampled_logits, self.top_p),
lambda: sampled_logits)
topk_ids = tf.random.categorical(
sampled_logits, dtype=tf.int32, num_samples=1)
topk_log_probs = tf.gather(
original_log_probs, topk_ids, axis=1, batch_dims=1)
if self.padded_decode:
......@@ -185,7 +318,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
tf.TensorShape([None, 1]),
decoding_module.StateKeys.ALIVE_CACHE:
tf.nest.map_structure(
decoding_module.DecodingModule._get_shape_keep_last_dim,
decoding_module.get_shape_keep_last_dim,
alive_cache),
decoding_module.StateKeys.FINISHED_SEQ:
tf.TensorShape([None, None]),
......@@ -288,9 +421,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm = self.length_normalization_fn(self.max_decode_length + 1,
self.dtype)
alive_log_probs = alive_log_probs / length_norm
seq_cond = decoding_module.DecodingModule._expand_to_same_rank(
seq_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_seq)
score_cond = decoding_module.DecodingModule._expand_to_same_rank(
score_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
......@@ -306,68 +439,6 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
return new_finished_flags
@property
def do_sample(self) -> bool:
"""Returns True if top_p or top_k is enabled."""
# TODO(poorvap) : Add the check for top_p.
if self.top_k is not None:
return True
return False
@staticmethod
def _greedy(log_probs):
"""Returns the top ids and scores based on greedy decoding."""
log_probs, ids = tf.nn.top_k(log_probs, k=1)
return log_probs, ids
@staticmethod
def sample_logits_with_temperature(logits, temperature):
"""Applies a sampling temperature.
Temperature of [0, 1) skews the distribution towards high-probability
tokens and lowers the mass in the tail of the distribution.
Args:
logits: Input logits for next token.
temperature: Tensor for specifying the sampling temperature.
Returns:
Logits with applied temperature.
"""
return logits / temperature
@staticmethod
def _sample_top_k(logits, top_k):
"""Chooses top_k logits and sets the others to negative infinity.
Args:
logits: Input logits for next token.
top_k: Tensor to specify the top_k values.
Returns:
Logits with top_k filtering applied.
"""
top_k_logits = tf.math.top_k(logits, k=top_k)
indices_to_remove = logits < top_k_logits[0][..., -1, None]
top_k_logits = SamplingModule._set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
return top_k_logits
@staticmethod
def _set_tensor_by_indices_to_value(input_tensor, indices, value):
"""Where indices is True, set the value in input_tensor to value.
Args:
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
value_tensor = tf.zeros_like(input_tensor) + value
output_tensor = tf.where(indices, value_tensor, input_tensor)
return output_tensor
......
......@@ -24,6 +24,8 @@ def length_norm(length, dtype):
"""Return length normalization factor."""
return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)
greedy_expected = tf.constant([[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]])
class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
......@@ -32,7 +34,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
} for layer in range(2)}
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
[[0.2, 0.4, 0.4], [0.2, 0.7, 0.1],
[[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def _get_test_symbols_to_logits_fn(self):
......@@ -58,7 +60,7 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
padded_decode=padded_decode)
ids, _ = greedy_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual([[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]], ids)
self.assertAllEqual(greedy_expected, ids)
@parameterized.named_parameters([
('padded_decode_true', True),
......@@ -72,12 +74,104 @@ class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=tf.constant(0.1),
sample_temperature=tf.constant(1.0),
top_k=tf.constant(3),
padded_decode=padded_decode)
padded_decode=padded_decode,
enable_greedy=False)
tf.random.set_seed(1)
ids, _ = top_k_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual([2, 5], ids.shape)
top_k_expected = tf.constant([[9, 1, 0, 2, 2], [1, 0, 1, 1, 0]])
self.assertAllEqual(top_k_expected, ids)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_topp(self, padded_decode):
top_p_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=tf.constant(1.0),
top_p=tf.constant(0.9),
padded_decode=padded_decode,
enable_greedy=False)
tf.random.set_seed(1)
ids, _ = top_p_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
top_p_expected = tf.constant([[9, 1, 0, 2, 2], [1, 0, 1, 2, 0]])
self.assertAllEqual(top_p_expected, ids)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_sampling_equivalent_greedy(self, padded_decode):
# Ensure that p=0.0 with no sample temperature is the same as greedy.
top_p_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=0.0,
top_p=tf.constant(0.0),
padded_decode=padded_decode,
enable_greedy=False)
ids, _ = top_p_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual(greedy_expected, ids)
# Ensure that k=1 with no sample temperature is the same as greedy.
top_k_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=0.0,
top_k=tf.constant(1),
padded_decode=padded_decode,
enable_greedy=False)
ids, _ = top_k_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual(greedy_expected, ids)
# Ensure that a low sample temperature results in a sharp distribution
# (greedy).
low_temperature_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=0.0001,
padded_decode=padded_decode)
ids, _ = low_temperature_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual(greedy_expected, ids)
# Ensure that a high sample temperature results in a flat distribution
# (random).
high_temperature_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=10.0,
padded_decode=padded_decode,
enable_greedy=False)
tf.random.set_seed(1)
ids, _ = high_temperature_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
expected = tf.constant([[9, 0, 0, 2, 2], [1, 0, 0, 0, 0]])
self.assertAllEqual(expected, ids)
if __name__ == '__main__':
tf.test.main()