# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2.
"""

import tensorflow as tf

from official.nlp.transformer import beam_search_v1 as v1
from official.transformer.v2 import misc

_StateKeys = v1._StateKeys  # pylint: disable=protected-access


class SequenceBeamSearchV2(v1.SequenceBeamSearch):
  """Implementation of beam search loop in v2.

  Reuses the state construction, loop condition and loop body inherited from
  `v1.SequenceBeamSearch`; only `search` is overridden here.
  """

  def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    Args:
      initial_ids: Starting ids, forwarded to `_create_initial_state`.
      initial_cache: Starting decoder cache, forwarded to
        `_create_initial_state`.

    Returns:
      A tuple `(finished_seq, finished_scores)`. For every batch item that has
      at least one finished beam these are the finished sequences and their
      scores; for the remaining batch items the alive sequences and their log
      probabilities are returned instead.
    """
    state, state_shapes = self._create_initial_state(initial_ids, initial_cache)

    # Run the decoding loop and block gradients through every tensor of the
    # resulting state structure.
    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(self._continue_search,
                      self._search_step,
                      loop_vars=[state],
                      shape_invariants=[state_shapes],
                      parallel_iterations=1))
    # `loop_vars` was a one-element list, so unwrap the single state entry.
    finished_state = finished_state[0]

    alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

    # 2.0 changes tf.where behavior. Should make parameters broadcastable.
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    seq_cond = _expand_to_same_rank(finished_cond, finished_seq)
    score_cond = _expand_to_same_rank(finished_cond, finished_scores)

    # Account for corner case where there are no finished sequences for a
    # particular batch item. In that case, return alive sequences for that batch
    # item.
    finished_seq = tf.compat.v2.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.compat.v2.where(
        score_cond, finished_scores, alive_log_probs)
    return finished_seq, finished_scores


def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and new cache:
        logits -> A tensor with shape [batch * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          inputted cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary, containing starting decoder variables
      information.
    vocab_size: An integer, the size of tokens.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of eos token, used to determine when a sequence has
      finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      "float32".

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  # Padded decoding reads the batch size from the static shape of
  # `initial_ids`; otherwise it is taken from the dynamic shape tensor.
  batch_size = (
      initial_ids.shape.as_list()[0] if padded_decode else
      tf.shape(initial_ids)[0])

  # Both implementations take the same constructor arguments, so pick the
  # class first and build the searcher once.
  search_cls = SequenceBeamSearchV2 if misc.is_v2() else v1.SequenceBeamSearch
  sbs = search_cls(symbols_to_logits_fn, vocab_size, batch_size,
                   beam_size, alpha, max_decode_length, eos_id,
                   padded_decode, dtype)
  return sbs.search(initial_ids, initial_cache)


def _expand_to_same_rank(tensor, target):
  """Appends trailing size-1 axes to `tensor` until it has `target`'s rank.

  The result can then be broadcast against `target`.

  Args:
    tensor: tensor to expand. Shape: [b, d1, ..., da]
    target: tensor whose rank should be matched.
      Shape: [b, d1, ..., da, ..., dn]

  Returns:
    `tensor` with shape [b, d1, ..., da, 1, ..., 1], i.e. one trailing axis
    of size 1 for every rank that `target` has beyond `tensor`. When `target`
    has no extra ranks, `tensor` is returned as is.

  Raises:
    ValueError: if the static rank of `tensor` or of `target` is None.
  """
  source_rank = tensor.shape.rank
  if source_rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  target_rank = target.shape.rank
  if target_rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  expanded = tensor
  with tf.name_scope("expand_rank"):
    missing_axes = target_rank - source_rank
    while missing_axes > 0:
      # Each pass adds one new innermost axis of size 1.
      expanded = tf.expand_dims(expanded, -1)
      missing_axes -= 1
  return expanded