# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipeline for the transformer model to read, filter, and batch examples.

Two things to note in the pipeline:

1. Batching scheme

   The examples encoded in the TFRecord files contain data in the format:
     {"inputs": [variable length array of integers],
      "targets": [variable length array of integers]}
   where the integers in the arrays refer to tokens in the English-German
   vocabulary file (named `vocab.ende.32768`).

   Prior to batching, elements in the dataset are grouped by length (max between
   "inputs" and "targets" length). Each group is then batched such that:
     group_batch_size * length <= batch_size.

   Another way to view batch_size is the maximum number of tokens in each batch.
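
   For example, if batch_size is 4096, a group whose sequences are at most 64
   tokens long is batched roughly 4096 // 64 = 64 examples at a time.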

   Once batched, each element in the dataset will have the shape:
     {"inputs": [group_batch_size, padded_input_length],
      "targets": [group_batch_size, padded_target_length]}
   Lengths are padded to the longest "inputs" or "targets" sequence in the batch
   (padded_input_length and padded_target_length can be different).

   This batching scheme decreases the fraction of padding tokens per training
   batch, thus improving the training speed significantly.

2. Shuffling

   During training, the dataset is shuffled in two places. First, the list of
   training files is shuffled. Second, while records are read with
   `parallel_interleave`, the `sloppy` argument introduces randomness into the
   order of the examples.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os

import tensorflow as tf

from official.utils.misc import model_helpers

# Buffer size for reading records from a TFRecord file. Each training file is
# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
_READ_RECORD_BUFFER = 8 * 1000 * 1000

# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1
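# With these defaults the boundaries start at 8 and grow by whichever is
# larger, +1 or *1.1 (rounded down), so they increase by one token at short
# lengths and roughly geometrically beyond ~20 tokens.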


def _load_records(filename):
  """Read file and return a dataset of tf.Examples."""
  return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)


def _parse_example(serialized_example):
  """Return inputs and targets Tensors from a serialized tf.Example."""
  data_fields = {
      "inputs": tf.VarLenFeature(tf.int64),
      "targets": tf.VarLenFeature(tf.int64)
  }
  parsed = tf.parse_single_example(serialized_example, data_fields)
  inputs = tf.sparse_tensor_to_dense(parsed["inputs"])
  targets = tf.sparse_tensor_to_dense(parsed["targets"])
  return inputs, targets


def _filter_max_length(example, max_length=256):
  """Indicates whether the example's length is lower than the maximum length."""
  return tf.logical_and(tf.size(example[0]) <= max_length,
                        tf.size(example[1]) <= max_length)


def _get_example_length(example):
  """Returns the maximum length between the example inputs and targets."""
  length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
  return length


def _create_min_max_boundaries(
    max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
  """Create min and max boundary lists up to max_length.

  For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
  returned values will be:
    buckets_min = [0, 4, 8, 16]
    buckets_max = [4, 8, 16, 25]

  Args:
    max_length: The maximum length of an example in the dataset.
    min_boundary: Minimum length boundary.
    boundary_scale: Amount by which to scale consecutive boundaries in the list.

  Returns:
    min and max boundary lists
  """
  # Create bucket boundaries list by scaling the previous boundary or adding 1
  # (to ensure increasing boundary sizes).
  bucket_boundaries = []
  x = min_boundary
  while x < max_length:
    bucket_boundaries.append(x)
    x = max(x + 1, int(x * boundary_scale))

  # Create min and max boundary lists from the initial list.
  buckets_min = [0] + bucket_boundaries
  buckets_max = bucket_boundaries + [max_length + 1]
  return buckets_min, buckets_max


def _batch_examples(dataset, batch_size, max_length):
  """Group examples by similar lengths, and return batched dataset.

  The examples in each batch are padded to the same length, and different
  batches may contain different numbers of examples, such that:
    group_batch_size * padded_length <= batch_size.

  This decreases the number of padding tokens per batch, which improves the
  training speed.

  Args:
    dataset: Dataset of unbatched examples.
    batch_size: Max number of tokens per batch of examples.
    max_length: Max number of tokens in an example input or target sequence.

  Returns:
    Dataset of batched examples with similar lengths.
  """
  # Get the min and max boundary lists. These are used to calculate the
  # `bucket_id` of each example, which is the index at which:
  #   buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
  # Note that using both min and max lists improves the performance.
  buckets_min, buckets_max = _create_min_max_boundaries(max_length)

  # Create list of batch sizes for each bucket_id, so that
  # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
  bucket_batch_sizes = [batch_size // x for x in buckets_max]
  # bucket_id will be a tensor, so convert this list to a tensor as well.
  bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)

  def example_to_bucket_id(example_input, example_target):
    """Return int64 bucket id for this example, calculated based on length."""
    seq_length = _get_example_length((example_input, example_target))

    # TODO: investigate whether removing code branching improves performance.
    conditions_c = tf.logical_and(
        tf.less_equal(buckets_min, seq_length),
        tf.less(seq_length, buckets_max))
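    # The buckets are contiguous and non-overlapping, so exactly one entry of
    # conditions_c is True; tf.where returns its index and reduce_min turns
    # that [1, 1] result into a scalar bucket id.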
    bucket_id = tf.reduce_min(tf.where(conditions_c))
    return bucket_id

  def window_size_fn(bucket_id):
    """Return number of examples to be grouped when given a bucket id."""
    return bucket_batch_sizes[bucket_id]

  def batching_fn(bucket_id, grouped_dataset):
    """Batch and add padding to a dataset of elements with similar lengths."""
    bucket_batch_size = window_size_fn(bucket_id)

    # Batch the dataset and add padding so that all input sequences in the
    # batch have the same length, and all target sequences have the same
    # length as well. The padded input and target lengths may differ.
    return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))

  return dataset.apply(tf.contrib.data.group_by_window(
      key_func=example_to_bucket_id,
      reduce_func=batching_fn,
      window_size=None,
      window_size_func=window_size_fn))


def _read_and_batch_from_files(
    file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
    static_batch=False):
  """Create dataset where each item is a dict of "inputs" and "targets".

  Args:
    file_pattern: String used to match the input TFRecord files.
    batch_size: Maximum number of tokens per batch of examples.
    max_length: Maximum number of tokens per example.
    num_parallel_calls: Number of CPU cores for parallel input processing.
    shuffle: If True, randomizes the order of elements.
    repeat: Number of times to repeat the dataset. If None, the dataset is
      repeated forever.
    static_batch: Whether the batches in the dataset should have static shapes.
      If True, the input is batched so that every batch has the
      shape [batch_size // max_length, max_length]. If False, the input is
      grouped by length, and batched so that batches may have different
      shapes [N, M], where:
        N * M <= batch_size
        M <= max_length
      In general, this setting should be False. Dynamic shapes allow the inputs
      to be grouped so that the number of padding tokens is minimized, and helps
      model training. In cases where the input shape must be static
      (e.g. running on TPU), this setting should be set to True.
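      For example, with batch_size=4096 and max_length=256, each static batch
      has shape [4096 // 256, 256] = [16, 256].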

  Returns:
    tf.data.Dataset object containing examples loaded from the files.
  """
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)

  # Read files and interleave results. When training, the order of the examples
  # will be non-deterministic.
  dataset = dataset.apply(
      tf.contrib.data.parallel_interleave(
          _load_records, sloppy=shuffle, cycle_length=num_parallel_calls))

  # Parse each tf.Example into a tuple of dense (inputs, targets) tensors.
  # TODO: Look into prefetch_input_elements for performance optimization.
  dataset = dataset.map(_parse_example,
                        num_parallel_calls=num_parallel_calls)

  # Remove examples where the input or target length exceeds the maximum length.
  dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))

  if static_batch:
    dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder(
        batch_size // max_length, ([max_length], [max_length])))
  else:
    # Group and batch such that each batch has examples of similar length.
    dataset = _batch_examples(dataset, batch_size, max_length)

  dataset = dataset.repeat(repeat)

  # Prefetch the next element to improve speed of input pipeline.
  dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
  return dataset


def _generate_synthetic_data(params):
  """Create synthetic data based on the parameter batch size."""
  batch = length = int(math.sqrt(params["batch_size"]))
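  # E.g. batch_size=4096 yields synthetic batches of shape [64, 64], so each
  # synthetic batch also contains batch_size tokens.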
  return model_helpers.generate_synthetic_data(
      input_shape=tf.TensorShape([batch, length]),
      input_value=1,
      input_dtype=tf.int32,
      label_shape=tf.TensorShape([batch, length]),
      label_value=1,
      label_dtype=tf.int32,
  )


def train_input_fn(params):
  """Load and return dataset of batched examples for use during training."""
  file_pattern = os.path.join(params["data_dir"] or "", "*train*")
  if params["use_synthetic_data"]:
    return _generate_synthetic_data(params)
  return _read_and_batch_from_files(
      file_pattern, params["batch_size"], params["max_length"],
      params["num_parallel_calls"], shuffle=True,
      repeat=params["repeat_dataset"], static_batch=params["static_batch"])


def eval_input_fn(params):
  """Load and return dataset of batched examples for use during evaluation."""
  file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
  if params["use_synthetic_data"]:
    return _generate_synthetic_data(params)
  return _read_and_batch_from_files(
      file_pattern, params["batch_size"], params["max_length"],
      params["num_parallel_calls"], shuffle=False, repeat=1,
      static_batch=params["static_batch"])
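

# A minimal usage sketch (illustrative values only; the params dict is normally
# supplied by the surrounding model code, and the data_dir path below is
# hypothetical):
#
#   params = {
#       "data_dir": "/path/to/tfrecords",
#       "batch_size": 4096,
#       "max_length": 256,
#       "num_parallel_calls": 4,
#       "repeat_dataset": None,
#       "static_batch": False,
#       "use_synthetic_data": False,
#   }
#   dataset = train_input_fn(params)  # yields batched (inputs, targets) pairs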