"vscode:/vscode.git/clone" did not exist on "f565b808ed3208c2065b1ba889589eafadea0102"
decoding_module.py 11 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Poorva Potdar's avatar
Poorva Potdar committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frederick Liu's avatar
Frederick Liu committed
14

Poorva Potdar's avatar
Poorva Potdar committed
15
16
17
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""

import abc
18
from typing import Any, Callable, Dict, Optional, Tuple
Poorva Potdar's avatar
Poorva Potdar committed
19
20
21
22

import tensorflow as tf

from tensorflow.python.framework import dtypes
Poorva Potdar's avatar
Poorva Potdar committed
23
from official.modeling import tf_utils
Poorva Potdar's avatar
Poorva Potdar committed
24

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
# Final result of a decode: (finished sequences, their scores, and —
# optionally — the model cache captured after the first step).
Output = Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]
Poorva Potdar's avatar
Poorva Potdar committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Per-step result of growing sequences:
# (top sequences, their log probs, new token ids, new cache dict).
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
# Initial loop setup: (state dict, matching shape-invariant dict).
InitialState = Tuple[Dict[str, Any], Dict[str, Any]]


class StateKeys:
  """Keys to dictionary storing the state of Decoding loop."""

  # Variable storing the loop index.
  CUR_INDEX = "CUR_INDEX"

  # Top sequences that are alive for each batch item. Alive sequences are ones
  # that have not generated an EOS token. Sequences that reach EOS are marked
  # as finished and moved to the FINISHED_SEQ tensor.
  # Has shape [batch_size, beam_size, CUR_INDEX + 1] for SequenceBeamSearch and
  # [batch_size, CUR_INDEX + 1] otherwise.
  ALIVE_SEQ = "ALIVE_SEQ"
  # Log probabilities of each alive sequence. Shape [batch_size, beam_size]
  ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
  # Dictionary of cached values for each alive sequence. The cache stores
  # the encoder output, attention bias, and the decoder attention output from
  # the previous iteration.
  ALIVE_CACHE = "ALIVE_CACHE"

  # The initial model state/cache after model processing the initial token.
  # The cache will be filled if extra_cache_output is true.
  INITIAL_OUTPUT_CACHE = "INITIAL_OUTPUT_CACHE"

  # Top finished sequences for each batch item.
  # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
  # shorter than CUR_INDEX + 1 are padded with 0s.
  FINISHED_SEQ = "FINISHED_SEQ"
  # Scores for each finished sequence. Score = log probability / length norm
  # Shape [batch_size, beam_size]
  FINISHED_SCORES = "FINISHED_SCORES"
  # Flags indicating which sequences in the finished sequences are finished.
  # At the beginning, all of the sequences in FINISHED_SEQ are filler values.
  # True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
  FINISHED_FLAGS = "FINISHED_FLAGS"


Poorva Potdar's avatar
Poorva Potdar committed
66
67
68
69
70
71
def log_prob_from_logits(logits):
  """Normalizes logits into log-probabilities along the last axis."""
  log_normalizer = tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
  return logits - log_normalizer


def shape_list(tensor):
  """Return a list of the tensor's shape, and ensure no None values in list."""
  # Delegates to the shared helper so static dims stay Python ints and
  # dynamic dims become scalar Tensors (never None).
  return tf_utils.get_shape_list(tensor)
Poorva Potdar's avatar
Poorva Potdar committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


def get_shape_keep_last_dim(tensor):
  """Builds a TensorShape where only a static last dimension is kept.

  Every leading dimension is marked unknown (None); the last dimension is
  also marked unknown if it is dynamic (a Tensor rather than a Python int).
  """
  dims = shape_list(tensor)
  # Blank out all leading dimensions.
  dims[:-1] = [None] * (len(dims) - 1)
  # A dynamic last dimension cannot appear in a TensorShape either.
  if isinstance(dims[-1], tf.Tensor):
    dims[-1] = None
  return tf.TensorShape(dims)


def expand_to_same_rank(tensor, target):
  """Expands a given tensor to target's rank to be broadcastable.

  Args:
    tensor: input tensor to tile. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target

  Raises:
    ValueError, if the shape rank of rank tensor/target is None.
  """
  # Both ranks must be statically known to compute how many axes to add.
  if tensor.shape.rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  if target.shape.rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  with tf.name_scope("expand_rank"):
    missing_axes = target.shape.rank - tensor.shape.rank
    while missing_axes > 0:
      tensor = tf.expand_dims(tensor, -1)
      missing_axes -= 1
    return tensor


Poorva Potdar's avatar
Poorva Potdar committed
110
111
112
113
114
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
  """A base class for the API required for decoding (go/decoding-tf-nlp)."""

  def __init__(self,
               length_normalization_fn: Callable[[int, tf.DType], float],
               dtype: tf.DType = tf.float32,
               decoding_name: Optional[str] = None,
               extra_cache_output: bool = False):
    """Initialize the Decoding Module.

    Args:
      length_normalization_fn: Closure for returning length normalization
        parameter. Function accepts input as length, dtype and returns float.
      dtype: A tensorflow data type used for score computation. The default is
        tf.float32.
      decoding_name: an optional name for the decoding loop tensors.
      extra_cache_output: If true, the first cache will be in the states.
    """
    self.length_normalization_fn = length_normalization_fn
    self.dtype = tf.as_dtype(dtype)
    self.decoding_name = decoding_name
    # Fix: this parameter was accepted but never stored, while `generate`
    # reads `self.extra_cache_output` below.
    self.extra_cache_output = extra_cache_output

  def generate(self,
               initial_ids: tf.Tensor,
               initial_cache: Dict[str, tf.Tensor],
               initial_log_probs: Optional[tf.Tensor] = None) -> Output:
    """Implements the decoding strategy (beam_search or sampling).

    Args:
      initial_ids: initial ids to pass into the symbols_to_logits_fn. int
        tensor with shape [batch_size, 1]
      initial_cache: dictionary for caching model outputs from previous step.
      initial_log_probs: Optionally initial log probs if there is a prefix
        sequence we want to start to decode from.

    Returns:
      Tuple of tensors representing
        finished_sequence: shape [batch, max_seq_length]
        finished_scores: [batch]
        first_cache: The cache after init token
    """
    # NOTE(review): `self.padded_decode` is not set in this class's
    # __init__; presumably a concrete subclass defines it — confirm.
    # With padded decode the batch size must be static, hence the
    # shape.as_list() path; otherwise it may be dynamic.
    batch_size = (
        initial_ids.shape.as_list()[0]
        if self.padded_decode else tf.shape(initial_ids)[0])

    state, state_shapes = self._create_initial_state(initial_ids, initial_cache,
                                                     batch_size,
                                                     initial_log_probs)

    def _generate_step(state):
      """Runs one decoding step and returns the updated loop state."""
      topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
          state, batch_size)
      new_finished_flags = self._finished_flags(topk_ids, state)
      alive_state = self._get_new_alive_state(topk_seq,
                                              topk_log_probs,
                                              new_finished_flags,
                                              new_cache)
      finished_state = self._get_new_finished_state(state,
                                                    topk_seq,
                                                    topk_log_probs,
                                                    new_finished_flags,
                                                    batch_size)
      new_state = {
          StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1
      }
      new_state.update(alive_state)
      new_state.update(finished_state)
      if self.extra_cache_output:
        # Keep the cache produced by the very first step (index 0) for the
        # remainder of the loop; later steps carry the stored value forward.
        i = state[StateKeys.CUR_INDEX]
        old_cache = state[StateKeys.INITIAL_OUTPUT_CACHE]

        def update_with_cache(new_state, cache):
          """Updates new_state with cache."""
          new_state.update({StateKeys.INITIAL_OUTPUT_CACHE: cache})

        tf.cond(
            tf.equal(i, 0), lambda: update_with_cache(new_state, new_cache),
            lambda: update_with_cache(new_state, old_cache))
      return [new_state]

    # stop_gradient: decoding is inference-only; do not backprop through
    # the search loop.
    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(
            self._continue_search,
            _generate_step,
            loop_vars=[state],
            shape_invariants=[state_shapes],
            parallel_iterations=1,
            name=self.decoding_name))
    final_state = self._process_finished_state(finished_state[0])
    return final_state

  @abc.abstractmethod
  def _create_initial_state(
      self,
      initial_ids: tf.Tensor,
      initial_cache: Dict[str, tf.Tensor],
      batch_size: int,
      initial_log_probs: Optional[tf.Tensor] = None) -> InitialState:
    """Return initial state dictionary and its shape invariants."""
    pass

  @abc.abstractmethod
  def _grow_alive_seq(self,
                      state: Dict[str, Any],
                      batch_size: int) -> InternalState:
    """Grow alive sequences by one token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size

    Returns:
      Tuple of
      (Top sequences,
       Scores of returned sequences,
       New ids,
       New alive cache)
    """
    pass

  @abc.abstractmethod
  def _get_new_alive_state(
      self,
      new_seq: tf.Tensor,
      new_log_probs: tf.Tensor,
      new_finished_flags: tf.Tensor,
      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gather the sequences that are still alive.

    Args:
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor with shape
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape
      new_finished_flags: A boolean Tensor indicates which sequences are live.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _get_new_finished_state(self,
                              state: Dict[str, Any],
                              new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combine new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor.
      new_log_probs: Log probabilities of new sequences float32 tensor with
        shape.
      new_finished_flags: A boolean Tensor indicates which sequences are live.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """
    pass

  @abc.abstractmethod
  def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
    """Process the alive/finished state to return final sequences and scores."""
    pass

  @abc.abstractmethod
  def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
    """Returns a bool tensor if the decoding loop should continue."""
    pass

  @abc.abstractmethod
  def _finished_flags(self,
                      topk_ids: tf.Tensor,
                      state: Dict[str, Any]) -> tf.Tensor:
    """Calculate the finished flags."""
    pass

  def inf(self):
    """Returns a value close to infinity, but is still finite in `dtype`.

    This is useful to get a very large value that is still zero when multiplied
    by zero. The floating-point "Inf" value is NaN when multiplied by zero.

    Returns:
      A very large value.
    """
    if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
      return 1e7
    elif self.dtype == dtypes.float16:
      # float16 max (~65504) is small enough to use directly.
      return dtypes.float16.max
    else:
      raise AssertionError("Invalid dtype: %s" % self.dtype)