"tests/vscode:/vscode.git/clone" did not exist on "0af12f1f8a1682833c944354daeba0c9d9c0f342"
Commit cdc4cad7 authored by Poorva Potdar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 338521847
parent 4f50e2fc
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import abc
from typing import Any, Callable, Dict, Tuple
import tensorflow as tf
from tensorflow.python.framework import dtypes
Output = Tuple[tf.Tensor, tf.Tensor]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
InitialState = Tuple[Dict[str, Any], Dict[str, Any]]
class StateKeys:
"""Keys to dictionary storing the state of Decoding loop."""
# Variable storing the loop index.
CUR_INDEX = "CUR_INDEX"
# Top sequences that are alive for each batch item. Alive sequences are ones
# that have not generated an EOS token. Sequences that reach EOS are marked as
# finished and moved to the FINISHED_SEQ tensor.
# Has shape [batch_size, beam_size, CUR_INDEX + 1] for SequenceBeamSearch and
# [batch_size, CUR_INDEX + 1] otherwise.
ALIVE_SEQ = "ALIVE_SEQ"
# Log probabilities of each alive sequence. Shape [batch_size, beam_size]
ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
# Dictionary of cached values for each alive sequence. The cache stores
# the encoder output, attention bias, and the decoder attention output from
# the previous iteration.
ALIVE_CACHE = "ALIVE_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
FINISHED_SEQ = "FINISHED_SEQ"
# Scores for each finished sequence. Score = log probability / length norm
# Shape [batch_size, beam_size]
FINISHED_SCORES = "FINISHED_SCORES"
# Flags indicating which sequences in the finished sequences are finished.
# At the beginning, all of the sequences in FINISHED_SEQ are filler values.
# True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
FINISHED_FLAGS = "FINISHED_FLAGS"
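# Illustration only: for a sampling decoder with batch_size=2 at the start of
# decoding, the state dictionary built from these keys would look like
# {CUR_INDEX: 0, ALIVE_SEQ: int32 [2, 1], ALIVE_LOG_PROBS: float32 [2, 1],
#  ALIVE_CACHE: {...}, FINISHED_SEQ: int32 [2, 1],
#  FINISHED_SCORES: float32 [2, 1], FINISHED_FLAGS: bool [2, 1]}.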
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
"""A base class for the API required for decoding (go/decoding-tf-nlp)."""
def __init__(self,
length_normalization_fn: Callable[[int, tf.DType], float],
dtype: tf.DType = tf.float32):
"""Initialize the Decoding Module.
Args:
length_normalization_fn: Closure that computes the length normalization
factor. It takes (length, dtype) as input and returns a float.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self.length_normalization_fn = length_normalization_fn
self.dtype = tf.as_dtype(dtype)
def generate(self,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor]) -> Output:
"""Implements the decoding strategy (beam_search or sampling).
Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step.
Returns:
Tuple of tensors representing
finished_sequence: shape [batch, max_seq_length]
finished_scores: [batch]
"""
batch_size = (
initial_ids.shape.as_list()[0]
if self.padded_decode else tf.shape(initial_ids)[0])
state, state_shapes = self._create_initial_state(initial_ids,
initial_cache,
batch_size)
def _generate_step(state):
topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
state, batch_size)
new_finished_flags = self._finished_flags(topk_ids, state)
alive_state = self._get_new_alive_state(topk_seq,
topk_log_probs,
new_finished_flags,
new_cache)
finished_state = self._get_new_finished_state(state,
topk_seq,
topk_log_probs,
new_finished_flags,
batch_size)
new_state = {
StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1
}
new_state.update(alive_state)
new_state.update(finished_state)
return [new_state]
finished_state = tf.nest.map_structure(
tf.stop_gradient,
tf.while_loop(
self._continue_search,
_generate_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1))
final_state = self._process_finished_state(finished_state[0])
return final_state
@abc.abstractmethod
def _create_initial_state(self,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
batch_size: int) -> InitialState:
"""Return initial state dictionary and its shape invariants."""
pass
@abc.abstractmethod
def _grow_alive_seq(self,
state: Dict[str, Any],
batch_size: int) -> InternalState:
"""Grow alive sequences by one token.
Args:
state: A dictionary with the current loop state.
batch_size: The given batch size.
Returns:
Tuple of
(Top sequences,
Scores of returned sequences,
New ids,
New alive cache)
"""
pass
@abc.abstractmethod
def _get_new_alive_state(
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive.
Args:
new_seq: New sequences generated by growing the current alive sequences,
an int32 tensor.
new_log_probs: Log probabilities of the new sequences, a float32 tensor.
new_finished_flags: A boolean tensor indicating which sequences have just
finished.
new_cache: Dict of cached values for each sequence.
Returns:
Dictionary with alive keys from StateKeys.
"""
pass
@abc.abstractmethod
def _get_new_finished_state(self,
state: Dict[str, Any],
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]:
"""Combine new and old finished sequences.
Args:
state: A dictionary with the current loop state.
new_seq: New sequences generated by growing the current alive sequences,
an int32 tensor.
new_log_probs: Log probabilities of the new sequences, a float32 tensor.
new_finished_flags: A boolean tensor indicating which sequences have just
finished.
batch_size: The given batch size.
Returns:
Dictionary with finished keys from StateKeys.
"""
pass
@abc.abstractmethod
def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
"""Process the alive/finished state to return final sequences and scores."""
pass
@abc.abstractmethod
def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
"""Returns a bool tensor if the decoding loop should continue."""
pass
@abc.abstractmethod
def _finished_flags(self,
topk_ids: tf.Tensor,
state: Dict[str, Any]) -> tf.Tensor:
"""Calculate the finished flags."""
pass
def inf(self):
"""Returns a value close to infinity, but is still finite in `dtype`.
This is useful to get a very large value that is still zero when multiplied
by zero. The floating-point "Inf" value is NaN when multiplied by zero.
Returns:
A very large value.
"""
if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
return 1e7
elif self.dtype == dtypes.float16:
return dtypes.float16.max
else:
raise AssertionError("Invalid dtype: %s" % self.dtype)
@staticmethod
def _log_prob_from_logits(logits):
return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
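# For example (values rounded), logits [[1., 2., 3.]] yield log-probabilities
# [[-2.408, -1.408, -0.408]]; this is a numerically stable log-softmax.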
@staticmethod
def _shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# Get statically known shape (may contain None's for unknown dimensions)
shape = tensor.get_shape().as_list()
# Ensure that the shape values are not None
dynamic_shape = tf.shape(tensor)
for i in range(len(shape)): # pylint: disable=consider-using-enumerate
if shape[i] is None:
shape[i] = dynamic_shape[i]
return shape
@staticmethod
def _get_shape_keep_last_dim(tensor):
shape_list_obj = DecodingModule._shape_list(tensor)
for i in range(len(shape_list_obj) - 1):
shape_list_obj[i] = None
if isinstance(shape_list_obj[-1], tf.Tensor):
shape_list_obj[-1] = None
return tf.TensorShape(shape_list_obj)
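# For example, a tensor of static shape [2, 3, 8] yields
# TensorShape([None, None, 8]): every dimension except the last is relaxed,
# which makes it usable as a tf.while_loop shape invariant for tensors that
# grow along earlier dimensions.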
@staticmethod
def _expand_to_same_rank(tensor, target):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
tensor: input tensor to tile. Shape: [b, d1, ..., da]
target: target tensor. Shape: [b, d1, ..., da, ..., dn]
Returns:
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank as
target.
Raises:
ValueError: if the rank of tensor or target shape is None.
"""
if tensor.shape.rank is None:
raise ValueError("Expect rank for tensor shape, but got None.")
if target.shape.rank is None:
raise ValueError("Expect rank for target shape, but got None.")
with tf.name_scope("expand_rank"):
diff_rank = target.shape.rank - tensor.shape.rank
for _ in range(diff_rank):
tensor = tf.expand_dims(tensor, -1)
return tensor
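# For example, a tensor of shape [4] expanded against a target of shape
# [4, 7, 3] becomes shape [4, 1, 1], which then broadcasts over the target.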
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test decoding utility methods."""
import abc
import tensorflow as tf
from official.nlp.modeling.ops import decoding_module
def length_normalization(length, dtype):
"""Return length normalization factor."""
return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)
class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def __init__(self,
length_normalization_fn=length_normalization,
dtype=tf.float32):
super(TestSubclass, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
def _create_initial_state(self, initial_ids, initial_cache, batch_size):
pass
def _grow_alive_seq(self, state, batch_size):
pass
def _process_finished_state(self, finished_state):
pass
def _get_new_finished_state(self, state, new_seq, new_log_probs,
new_finished_flags, batch_size):
pass
def _finished_flags(self, topk_ids, state):
pass
def _continue_search(self, state):
pass
def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
new_cache):
pass
class DecodingModuleTest(tf.test.TestCase):
def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0)
x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
shape = decoding_module.DecodingModule._get_shape_keep_last_dim(x)
self.assertAllEqual([None, None, None, 5], shape.as_list())
def test_shape_list(self):
x = tf.ones([7, 1])
shape = decoding_module.DecodingModule._shape_list(x)
self.assertAllEqual([7, 1], shape)
def test_inf(self):
d = TestSubclass()
inf_value = d.inf()
self.assertAllEqual(inf_value, tf.constant(10000000., tf.float32))
def test_length_normalization(self):
d = TestSubclass()
normalized_length = d.length_normalization_fn(32, tf.float32)
self.assertAllEqual(normalized_length, tf.constant(1.0, tf.float32))
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling module for top_k, top_p and greedy decoding."""
import abc
from typing import Any, Callable, Dict, Optional
import numpy as np
import tensorflow as tf
from official.nlp.modeling.ops import decoding_module
class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
"""Implementation for sampling stratgies (go/decoding-tf-nlp)."""
def __init__(self,
symbols_to_logits_fn,
length_normalization_fn: Callable[[int, tf.DType], float],
vocab_size: int,
max_decode_length: int,
eos_id: int,
padded_decode: bool,
top_k: Optional[tf.Tensor] = None,
sample_temperature: Optional[tf.Tensor] = None,
dtype: tf.DType = tf.float32):
"""Initialize sampling module."""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
self.length_normalization_fn = length_normalization_fn
self.max_decode_length = max_decode_length
self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype)
self.top_k = top_k
self.sample_temperature = sample_temperature
super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
def _grow_alive_seq(self,
state: Dict[str, Any],
batch_size: int) -> decoding_module.InternalState:
"""Grow alive sequences by one token.
This function implements decoding strategies such as top_p, top_k and
greedy to choose the next token.
Args:
state: A dictionary with the current loop state.
batch_size: The given batch size.
Returns:
Tuple of
(Top sequences [batch, curr_index + 1] or [batch, max_decode_length + 1],
Scores of returned sequences [batch, 1],
New ids [batch, 1],
New alive cache)
"""
i = state[decoding_module.StateKeys.CUR_INDEX]
alive_seq = state[decoding_module.StateKeys.ALIVE_SEQ]
alive_log_probs = state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
alive_cache = state[decoding_module.StateKeys.ALIVE_CACHE]
if self.padded_decode:
ids = tf.slice(alive_seq, [0, i], [batch_size, 1])
else:
ids = alive_seq
new_logits, new_cache = self.symbols_to_logits_fn(ids, i, alive_cache)
candidate_log_probs = decoding_module.DecodingModule._log_prob_from_logits(
new_logits)
original_log_probs = candidate_log_probs + alive_log_probs
probs = original_log_probs
topk_log_probs, topk_ids = None, None
if not self.do_sample:
topk_log_probs, topk_ids = self._greedy(probs)
else:
temperature_fn = SamplingModule.sample_logits_with_temperature
probs = tf.cond(self.sample_temperature > 0.0,
lambda: temperature_fn(probs, self.sample_temperature),
lambda: probs)
probs = tf.cond(self.top_k is not None and self.top_k > 1,
lambda: SamplingModule._sample_top_k(probs, self.top_k),
lambda: probs)
topk_ids = tf.random.categorical(probs, dtype=tf.int32, num_samples=1)
topk_log_probs = tf.gather(
original_log_probs, topk_ids, axis=1, batch_dims=1)
if self.padded_decode:
topk_seq = tf.transpose(alive_seq, perm=[1, 0])
topk_seq = tf.tensor_scatter_nd_update(
topk_seq, [[i + 1]], tf.expand_dims(tf.squeeze(topk_ids, -1), 0))
topk_seq = tf.transpose(topk_seq, perm=[1, 0])
else:
topk_seq = tf.concat([alive_seq, topk_ids], axis=-1)
return topk_seq, topk_log_probs, topk_ids, new_cache
def _create_initial_state(self,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
batch_size: int) -> decoding_module.InitialState:
"""Return initial state dictionary and its shape invariants."""
for key, value in initial_cache.items():
for inner_value in tf.nest.flatten(value):
if inner_value.dtype != self.dtype:
raise TypeError(
"initial_cache element for key '%s' has dtype %s that does not "
"match SamplingModule's dtype of %s. Value: %s" %
(key, inner_value.dtype.name, self.dtype.name, inner_value))
# Current loop index (starts at 0)
cur_index = tf.constant(0)
# Alive sequence with shape [batch_size, 1]
alive_seq = initial_ids
alive_seq = tf.expand_dims(alive_seq, axis=-1)
if self.padded_decode:
alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1])
# Initial log probabilities with shape [batch_size, 1].
initial_log_probs = tf.constant([[0.]], dtype=self.dtype)
alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
alive_cache = initial_cache
# Initialize tensor storing finished sequences, shaped like alive_seq.
finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
# Initialize scores of the initial finished sequences to zeros.
finished_scores = tf.zeros([batch_size, 1], dtype=self.dtype)
# Initialize finished flags with all False values.
finished_flags = tf.zeros([batch_size, 1], tf.bool)
# Create state dictionary and state shapes.
state = {
decoding_module.StateKeys.CUR_INDEX: cur_index,
decoding_module.StateKeys.ALIVE_SEQ: alive_seq,
decoding_module.StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
decoding_module.StateKeys.ALIVE_CACHE: alive_cache,
decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
decoding_module.StateKeys.FINISHED_FLAGS: finished_flags
}
if self.padded_decode:
state_shape_invariants = {
decoding_module.StateKeys.CUR_INDEX:
tf.TensorShape([]),
decoding_module.StateKeys.ALIVE_SEQ:
tf.TensorShape(
[batch_size, self.max_decode_length + 1]),
decoding_module.StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([batch_size, 1]),
decoding_module.StateKeys.ALIVE_CACHE:
tf.nest.map_structure(lambda state: state.get_shape(),
alive_cache),
decoding_module.StateKeys.FINISHED_SEQ:
tf.TensorShape(
[batch_size, self.max_decode_length + 1]),
decoding_module.StateKeys.FINISHED_SCORES:
tf.TensorShape([batch_size, 1]),
decoding_module.StateKeys.FINISHED_FLAGS:
tf.TensorShape([batch_size, 1])
}
else:
state_shape_invariants = {
decoding_module.StateKeys.CUR_INDEX:
tf.TensorShape([]),
decoding_module.StateKeys.ALIVE_SEQ:
tf.TensorShape([None, None]),
decoding_module.StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([None, 1]),
decoding_module.StateKeys.ALIVE_CACHE:
tf.nest.map_structure(
decoding_module.DecodingModule._get_shape_keep_last_dim,
alive_cache),
decoding_module.StateKeys.FINISHED_SEQ:
tf.TensorShape([None, None]),
decoding_module.StateKeys.FINISHED_SCORES:
tf.TensorShape([None, 1]),
decoding_module.StateKeys.FINISHED_FLAGS:
tf.TensorShape([None, 1])
}
return state, state_shape_invariants
def _get_new_alive_state(
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive.
This function resets the sequences in the alive_state that are finished.
Args:
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, cur_index + 1]
new_log_probs: Log probabilities of new sequences float32 tensor with
shape [batch_size, 1]
new_finished_flags: A boolean tensor indicating which sequences have just
finished.
new_cache: Dict of cached values for each sequence.
Returns:
Dictionary with alive keys.
"""
new_seq = tf.multiply(
new_seq, tf.cast(tf.logical_not(new_finished_flags), new_seq.dtype))
return {
decoding_module.StateKeys.ALIVE_SEQ: new_seq,
decoding_module.StateKeys.ALIVE_LOG_PROBS: new_log_probs,
decoding_module.StateKeys.ALIVE_CACHE: new_cache
}
def _get_new_finished_state(self,
state: Dict[str, Any],
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]:
"""Combine new and old finished sequences.
Args:
state: A dictionary with the current loop state.
new_seq: New sequences generated by growing the current alive sequences
int32 tensor [batch, curr_index + 1] or [batch, max_decode_length + 1].
new_log_probs: Log probabilities of new sequences float32 tensor with
shape [batch, 1].
new_finished_flags: A boolean tensor indicating which sequences have finished.
batch_size: The given batch size.
Returns:
Dictionary with finished keys from StateKeys.
"""
i = state[decoding_module.StateKeys.CUR_INDEX]
finished_seq = state[decoding_module.StateKeys.FINISHED_SEQ]
finished_scores = state[decoding_module.StateKeys.FINISHED_SCORES]
finished_flags = state[decoding_module.StateKeys.FINISHED_FLAGS]
if not self.padded_decode:
finished_seq = tf.concat(
[finished_seq, tf.zeros([batch_size, 1], tf.int32)], axis=-1)
new_scores = new_log_probs
if self.length_normalization_fn is not None:
length_norm = self.length_normalization_fn(i + 1, self.dtype)
new_scores = new_log_probs / length_norm
new_seq = tf.multiply(
new_seq, tf.cast(tf.logical_not(finished_flags), new_seq.dtype))
new_scores = tf.multiply(
new_scores, tf.cast(tf.logical_not(finished_flags), new_scores.dtype))
finished_seq += tf.multiply(new_seq,
tf.cast(new_finished_flags, new_seq.dtype))
finished_scores += tf.multiply(
new_scores, tf.cast(new_finished_flags, new_scores.dtype))
new_finished_flags = tf.logical_or(new_finished_flags, finished_flags)
return {
decoding_module.StateKeys.FINISHED_SEQ: finished_seq,
decoding_module.StateKeys.FINISHED_SCORES: finished_scores,
decoding_module.StateKeys.FINISHED_FLAGS: new_finished_flags
}
def _process_finished_state(
self, finished_state: Dict[str, Any]) -> decoding_module.Output:
"""Process the alive/finished state to return final sequences and scores."""
alive_seq = finished_state[decoding_module.StateKeys.ALIVE_SEQ]
alive_log_probs = finished_state[decoding_module.StateKeys.ALIVE_LOG_PROBS]
finished_seq = finished_state[decoding_module.StateKeys.FINISHED_SEQ]
finished_scores = finished_state[decoding_module.StateKeys.FINISHED_SCORES]
finished_flags = finished_state[decoding_module.StateKeys.FINISHED_FLAGS]
finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
if self.length_normalization_fn is not None:
length_norm = self.length_normalization_fn(self.max_decode_length + 1,
self.dtype)
alive_log_probs = alive_log_probs / length_norm
seq_cond = decoding_module.DecodingModule._expand_to_same_rank(
finished_cond, finished_seq)
score_cond = decoding_module.DecodingModule._expand_to_same_rank(
finished_cond, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores
def _continue_search(self, state) -> tf.Tensor:
i = state[decoding_module.StateKeys.CUR_INDEX]
return tf.less(i, self.max_decode_length)
def _finished_flags(self, topk_ids, state) -> tf.Tensor:
new_finished_flags = tf.equal(topk_ids, self.eos_id)
new_finished_flags = tf.logical_or(
new_finished_flags, state[decoding_module.StateKeys.FINISHED_FLAGS])
return new_finished_flags
@property
def do_sample(self) -> bool:
"""Returns True if top_p or top_k is enabled."""
# TODO(poorvap) : Add the check for top_p.
if self.top_k is not None:
return True
return False
@staticmethod
def _greedy(log_probs):
"""Returns the top ids and scores based on greedy decoding."""
log_probs, ids = tf.nn.top_k(log_probs, k=1)
return log_probs, ids
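# For example, log_probs [[-0.1, -2.3, -4.6]] return ([[-0.1]], [[0]]):
# tf.nn.top_k with k=1 selects the argmax token and its log probability.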
@staticmethod
def sample_logits_with_temperature(logits, temperature):
"""Applies a sampling temperature.
A temperature in [0, 1) skews the distribution towards high-probability
tokens and lowers the mass in the tail of the distribution.
Args:
logits: Input logits for next token.
temperature: Tensor for specifying the sampling temperature.
Returns:
Logits with applied temperature.
"""
return logits / temperature
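# For example, logits [[2., 0.]] at temperature 0.5 become [[4., 0.]],
# sharpening the softmax distribution, while temperature 2.0 gives
# [[1., 0.]], flattening it.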
@staticmethod
def _sample_top_k(logits, top_k):
"""Chooses top_k logits and sets the others to negative infinity.
Args:
logits: Input logits for next token.
top_k: Tensor to specify the top_k values.
Returns:
Logits with top_k filtering applied.
"""
top_k_logits = tf.math.top_k(logits, k=top_k)
indices_to_remove = logits < top_k_logits[0][..., -1, None]
top_k_logits = SamplingModule._set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
return top_k_logits
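# For example, logits [[1., 3., 2., 0.]] with top_k=2 keep the two largest
# entries and mask the rest to -inf: [[-inf, 3., 2., -inf]].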
@staticmethod
def _set_tensor_by_indices_to_value(input_tensor, indices, value):
"""Where indices is True, set the value in input_tensor to value.
Args:
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
value_tensor = tf.zeros_like(input_tensor) + value
output_tensor = tf.where(indices, value_tensor, input_tensor)
return output_tensor
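# For example, input_tensor [[1., 2.]], indices [[True, False]] and value 0.
# yield [[0., 2.]].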
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Sampling Strategies."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.modeling.ops import sampling_module
def length_norm(length, dtype):
"""Return length normalization factor."""
return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)
class SamplingModuleTest(tf.test.TestCase, parameterized.TestCase):
cache = {'layer_%d' % layer: {'k': tf.zeros([2, 2, 2, 2], dtype=tf.float32),
'v': tf.zeros([2, 2, 2, 2], dtype=tf.float32)
} for layer in range(2)}
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
[[0.2, 0.4, 0.4], [0.2, 0.7, 0.1],
[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def _get_test_symbols_to_logits_fn(self):
"""Calculates logits of the next tokens."""
def symbols_to_logits_fn(ids, i, cache):
del ids
logits = tf.cast(tf.math.log(self.probabilities[:, i, :]), tf.float32)
return logits, cache
return symbols_to_logits_fn
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_greedy(self, padded_decode):
greedy_obj = sampling_module.SamplingModule(
length_normalization_fn=None,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
padded_decode=padded_decode)
ids, _ = greedy_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual([[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]], ids)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_topk(self, padded_decode):
top_k_obj = sampling_module.SamplingModule(
length_normalization_fn=length_norm,
dtype=tf.float32,
symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
vocab_size=3,
max_decode_length=4,
eos_id=10,
sample_temperature=tf.constant(0.1),
top_k=tf.constant(3),
padded_decode=padded_decode)
ids, _ = top_k_obj.generate(
initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
self.assertAllEqual([2, 5], ids.shape)
if __name__ == '__main__':
tf.test.main()