base.py
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base dataset builder classes for AstroWaveNet input pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

import tensorflow as tf

from tf_util import configdict
from astronet.ops import dataset_ops


@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
  """Base class for building a dataset input pipeline for AstroWaveNet."""

  def __init__(self, config_overrides=None):
    """Initializes the dataset builder.

    Args:
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
    """
    self.config = configdict.ConfigDict(self.default_config())
    if config_overrides is not None:
      self.config.update(config_overrides)

  @staticmethod
  def default_config():
    """Returns the default configuration as a ConfigDict or Python dict."""
    return {}

  @abc.abstractmethod
  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size: The number of input examples in each batch.

    Returns:
      A tf.data.Dataset object.
    """
    raise NotImplementedError
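

# Illustration only: a minimal sketch of a concrete DatasetBuilder, assuming
# only the APIs used elsewhere in this module. The class name and the
# "num_values" config key are hypothetical and not part of the original code.
class _ExampleRangeDataset(DatasetBuilder):
  """Toy builder that batches a range of integers; for illustration only."""

  @staticmethod
  def default_config():
    # Subclasses return their default hyperparameters as a plain dict (or
    # ConfigDict); __init__ wraps it in a ConfigDict and applies overrides.
    return {"num_values": 100}

  def build(self, batch_size):
    # self.config supports attribute access, mirroring self.config.max_length
    # and similar usages in _ShardedDatasetBuilder below.
    dataset = tf.data.Dataset.range(self.config.num_values)
    return dataset.batch(batch_size)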


@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
  """Abstract base class for a dataset consisting of sharded files."""

  def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
    """Initializes the dataset builder.

    Args:
      file_pattern: File pattern matching input file shards, e.g.
        "/tmp/train-?????-of-00100". May also be a comma-separated list of file
        patterns.
      mode: A tf.estimator.ModeKeys.
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
      use_tpu: Whether to build the dataset for TPU.
    """
    super(_ShardedDatasetBuilder, self).__init__(config_overrides)
    self.file_pattern = file_pattern
    self.mode = mode
    self.use_tpu = use_tpu

  @staticmethod
  def default_config():
    config = super(_ShardedDatasetBuilder,
                   _ShardedDatasetBuilder).default_config()
    config.update({
        "max_length": 1024,
        "shuffle_values_buffer": 1000,
        "num_parallel_parser_calls": 4,
        "batches_buffer_size": None,  # Defaults to max(1, 256 / batch_size).
    })
    return config
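
  # Illustration (hypothetical override values): any of the keys above can be
  # overridden at construction time, e.g.
  #   builder = TFRecordDataset(
  #       file_pattern="/tmp/train-?????-of-00100",
  #       mode=tf.estimator.ModeKeys.TRAIN,
  #       config_overrides={"max_length": 512, "num_parallel_parser_calls": 8})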

  @abc.abstractmethod
  def file_reader(self):
    """Returns a function that reads a single sharded file."""
    raise NotImplementedError

  @abc.abstractmethod
  def create_example_parser(self):
    """Returns a function that parses a single tf.Example proto."""
    raise NotImplementedError

  def _batch_and_pad(self, dataset, batch_size):
    """Combines elements into batches of the same length, padding if needed."""
    if self.use_tpu:
      padded_length = self.config.max_length
      if not padded_length:
        raise ValueError("config.max_length is required when using TPU")
      # Pad with zeros up to padded_length. Note that this will pad the
      # "weights" Tensor with zeros as well, which ensures that padded elements
      # do not contribute to the loss.
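      # For example (illustrative shapes only): with max_length=1024, a
      # feature with static shape [None, 1] is padded to a fixed [1024, 1],
      # as required for TPU execution.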
      padded_shapes = {}
      for name, shape in dataset.output_shapes.items():
        shape.assert_is_compatible_with([None, None])  # Expect a 2D sequence.
        dims = shape.as_list()
        dims[0] = padded_length
        shape = tf.TensorShape(dims)
        shape.assert_is_fully_defined()
        padded_shapes[name] = shape
    else:
      # Pad each batch up to the maximum size of each dimension in the batch.
      padded_shapes = dataset.output_shapes

    return dataset.padded_batch(batch_size, padded_shapes)

  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size: The number of input examples in each batch.

    Returns:
      A tf.data.Dataset.

    Raises:
      ValueError: If no files match self.file_pattern.
    """
    file_patterns = self.file_pattern.split(",")
    filenames = []
    for p in file_patterns:
      matches = tf.gfile.Glob(p)
      if not matches:
        raise ValueError("Found no input files matching {}".format(p))
      filenames.extend(matches)
    tf.logging.info(
        "Building input pipeline from %d files matching patterns: %s",
        len(filenames), file_patterns)

    is_training = self.mode == tf.estimator.ModeKeys.TRAIN

    # Create a string dataset of filenames, and possibly shuffle.
    filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if is_training and len(filenames) > 1:
      filename_dataset = filename_dataset.shuffle(len(filenames))

    # Read serialized Example protos.
    dataset = filename_dataset.apply(
        tf.contrib.data.parallel_interleave(
            self.file_reader(), cycle_length=8, block_length=8, sloppy=True))

    if is_training:
      # Shuffle and repeat. Note that shuffle() comes before repeat(), so
      # elements are shuffled within each epoch of data rather than across
      # epoch boundaries.
      if self.config.shuffle_values_buffer > 0:
        dataset = dataset.shuffle(self.config.shuffle_values_buffer)
      dataset = dataset.repeat()

    # Map the parser over the dataset.
    dataset = dataset.map(
        self.create_example_parser(),
        num_parallel_calls=self.config.num_parallel_parser_calls)

    def _prepare_wavenet_inputs(features):
      """Validates features, and clips lengths and adds weights if needed."""
      # Validate feature names.
      required_features = {"autoregressive_input", "conditioning_stack"}
      allowed_features = required_features | {"weights"}
      feature_names = features.keys()
      if not required_features.issubset(feature_names):
        raise ValueError("Features must contain all of: {}. Got: {}".format(
            required_features, feature_names))
      if not allowed_features.issuperset(feature_names):
        raise ValueError("Features can only contain: {}. Got: {}".format(
            allowed_features, feature_names))

      output = {}
      for name, value in features.items():
        # Validate shapes. The expected output shape is [num_samples, dim].
        ndims = len(value.shape)
        if ndims == 1:
          # Add an extra dimension: [num_samples] -> [num_samples, 1].
          value = tf.expand_dims(value, -1)
        elif ndims != 2:
          raise ValueError(
              "Features should be 1D or 2D sequences. Got '{}' = {}".format(
                  name, value))
        if self.config.max_length:
          value = value[:self.config.max_length]
        output[name] = value

      if "weights" not in output:
        output["weights"] = tf.ones_like(output["autoregressive_input"])

      return output

    dataset = dataset.map(_prepare_wavenet_inputs)

    # Batch results by up to batch_size.
    dataset = self._batch_and_pad(dataset, batch_size)

    if is_training:
      # The dataset repeats infinitely before batching, so each batch has the
      # maximum number of elements.
      dataset = dataset_ops.set_batch_size(dataset, batch_size)
    elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
      # Pad to ensure that each batch has the same number of elements.
      dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)

    # Prefetch batches.
    buffer_size = (
        self.config.batches_buffer_size or max(1, int(256 / batch_size)))
    dataset = dataset.prefetch(buffer_size)

    return dataset


def tfrecord_reader(filename):
  """Returns a tf.data.Dataset that reads a single TFRecord file shard."""
  return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)
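

# Illustration only: file_reader() lets subclasses plug in other file formats.
# This sketch (hypothetical; not part of the original module) reads
# fixed-length binary records instead of TFRecords. The record size of 8192
# bytes is an arbitrary assumed value.
def _example_fixed_length_reader(filename, record_bytes=8192):
  """Returns a tf.data.Dataset reading one file of fixed-length records."""
  return tf.data.FixedLengthRecordDataset(filename, record_bytes=record_bytes)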


class TFRecordDataset(_ShardedDatasetBuilder):
  """Builder for a dataset consisting of TFRecord files."""

  def file_reader(self):
    """Returns a function that reads a single file shard."""
    return tfrecord_reader
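

# Illustration only: a hedged sketch of a concrete TFRecordDataset subclass.
# The feature encoding (variable-length float lists named
# "autoregressive_input" and "conditioning_stack") is an assumption about the
# input files, not something defined by this module.
class _ExampleWaveNetTFRecordDataset(TFRecordDataset):
  """Example subclass parsing tf.Example protos with two float features."""

  def create_example_parser(self):
    def _parse_example(serialized):
      features = tf.parse_single_example(
          serialized,
          features={
              "autoregressive_input": tf.VarLenFeature(tf.float32),
              "conditioning_stack": tf.VarLenFeature(tf.float32),
          })
      # Densify each variable-length feature into a [num_samples] Tensor;
      # _prepare_wavenet_inputs later expands it to [num_samples, 1] and adds
      # the "weights" feature.
      return {
          name: tf.sparse_tensor_to_dense(value)
          for name, value in features.items()
      }

    return _parse_example


# Hypothetical usage (file pattern and batch size are made up):
#   builder = _ExampleWaveNetTFRecordDataset(
#       file_pattern="/tmp/train-?????-of-00100",
#       mode=tf.estimator.ModeKeys.TRAIN)
#   dataset = builder.build(batch_size=16)
#   # Each dataset element is a dict with keys "autoregressive_input",
#   # "conditioning_stack" and "weights", each of shape
#   # [batch_size, num_samples, 1].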