# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model input pipelines."""

import tensorflow as tf


20
21
22
23
24
25
26
27
28
29
30
def decode_record(record, name_to_features):
  """Parses one serialized record into a dict of feature tensors.

  Args:
    record: A serialized example, passed to `tf.io.parse_single_example`.
    name_to_features: Mapping from feature name to its parsing spec (for
      instance `tf.io.FixedLenFeature`), passed to
      `tf.io.parse_single_example`.

  Returns:
    The parsed feature dict in which every int64 tensor has been cast to
    int32; tensors of any other dtype are returned unchanged.
  """
  parsed = tf.io.parse_single_example(record, name_to_features)

  # tf.Example only stores integers as int64, but the TPU only supports
  # int32, so downcast each int64 feature in place.
  for key, tensor in list(parsed.items()):
    if tensor.dtype == tf.int64:
      parsed[key] = tf.cast(tensor, tf.int32)

  return parsed
33
34


35
def single_file_dataset(input_file, name_to_features, num_samples=None):
  """Builds a dataset of decoded records from TFRecord input.

  Args:
    input_file: A path string, or a sequence of paths, handed to
      `tf.data.TFRecordDataset`.
    name_to_features: Mapping from feature name to parsing spec, used by
      `decode_record` to parse every record.
    num_samples: Optional. When truthy, only the first `num_samples` raw
      records are kept (applied before decoding).

  Returns:
    A `tf.data.Dataset` of decoded feature dicts. When `input_file` is a
    single path (a string or a one-element sequence) auto sharding is turned
    off on the returned dataset.
  """
  dataset = tf.data.TFRecordDataset(input_file)
  if num_samples:
    dataset = dataset.take(num_samples)

  def _decode(serialized):
    return decode_record(serialized, name_to_features)

  dataset = dataset.map(
      _decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  reads_single_file = isinstance(input_file, str) or len(input_file) == 1
  if not reads_single_file:
    return dataset

  # With only one input file there is nothing to split by file, so disable
  # auto sharding and let the same file be sent to all workers.
  shard_off = tf.data.Options()
  shard_off.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  return dataset.with_options(shard_off)
55
56


57
def create_pretrain_dataset(input_patterns,
                            seq_length,
                            max_predictions_per_seq,
                            batch_size,
                            is_training=True,
                            input_pipeline_context=None,
                            use_next_sentence_label=True,
                            use_position_id=False,
                            output_fake_labels=True):
  """Creates input dataset from (tf)records files for pretraining.

  Args:
    input_patterns: A list of file patterns matching the TFRecord files to
      read. A single pattern string is also accepted and is treated as a
      one-element list.
    seq_length: Length of the `input_ids`, `input_mask`, `segment_ids` (and
      optional `position_ids`) features.
    max_predictions_per_seq: Length of the `masked_lm_positions`,
      `masked_lm_ids` and `masked_lm_weights` features.
    batch_size: Number of examples per batch.
    is_training: When True the file list is shuffled and repeated, records
      are shuffled, and the final partial batch is dropped.
    input_pipeline_context: Optional object exposing `num_input_pipelines`
      and `input_pipeline_id`. When it reports more than one pipeline, the
      list of files is sharded across pipelines.
    use_next_sentence_label: Whether to parse and emit
      `next_sentence_labels`.
    use_position_id: Whether to parse and emit `position_ids`.
    output_fake_labels: When True each element is a tuple
      `(features, masked_lm_weights)`; otherwise only `features` is emitted.

  Returns:
    A batched and prefetched `tf.data.Dataset`.

  Raises:
    ValueError: If any pattern in `input_patterns` matches no files.
  """
  # A bare string would otherwise be iterated character by character by the
  # validation loop below.
  if isinstance(input_patterns, str):
    input_patterns = [input_patterns]

  name_to_features = {
      'input_ids':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'masked_lm_positions':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_ids':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_weights':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
  }
  if use_next_sentence_label:
    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
                                                                     tf.int64)
  if use_position_id:
    name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
                                                             tf.int64)

  # Glob every pattern exactly once: the result both validates the pattern
  # and provides the file count used to size the file-level shuffle buffer.
  input_files = []
  for input_pattern in input_patterns:
    matched_files = tf.io.gfile.glob(input_pattern)
    if not matched_files:
      raise ValueError('%s does not match any files.' % input_pattern)
    input_files.extend(matched_files)

  dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)

  # Shard at the file level so each input pipeline reads its own files.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)
  if is_training:
    dataset = dataset.repeat()

    # We set shuffle buffer to exactly match total number of
    # training files to ensure that training data is well shuffled.
    dataset = dataset.shuffle(len(input_files))

  # In parallel, create tf record dataset for each train files.
  # cycle_length = 8 means that up to 8 files will be read and deserialized in
  # parallel. You may want to increase this number if you have a large number of
  # CPU cores.
  dataset = dataset.interleave(
      tf.data.TFRecordDataset,
      cycle_length=8,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Record-level shuffle on top of the file-level shuffle above.
  if is_training:
    dataset = dataset.shuffle(100)

  decode_fn = lambda record: decode_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Filter out features to use for pretraining."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids'],
        'masked_lm_positions': record['masked_lm_positions'],
        'masked_lm_ids': record['masked_lm_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    if use_position_id:
      x['position_ids'] = record['position_ids']

    # TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
    if output_fake_labels:
      return (x, record['masked_lm_weights'])
    else:
      return x

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset


def create_classifier_dataset(file_path,
                              seq_length,
                              batch_size,
                              is_training=True,
                              input_pipeline_context=None,
                              label_type=tf.int64,
                              include_sample_weights=False,
                              num_samples=None):
  """Builds the train/eval dataset for a classifier from (tf)records.

  Args:
    file_path: TFRecord path (or sequence of paths) given to
      `single_file_dataset`.
    seq_length: Length of the `input_ids`, `input_mask` and `segment_ids`
      features.
    batch_size: Number of examples per batch.
    is_training: When True records are shuffled and repeated, and the final
      partial batch is dropped.
    input_pipeline_context: Optional object exposing `num_input_pipelines`
      and `input_pipeline_id`; when it reports more than one pipeline the
      records are sharded across pipelines.
    label_type: dtype used to parse the scalar `label_ids` feature.
    include_sample_weights: Whether to parse the scalar float `weight`
      feature and emit it as a third tuple element.
    num_samples: Optional cap on the number of records, forwarded to
      `single_file_dataset`.

  Returns:
    A batched and prefetched `tf.data.Dataset` of `(features, labels)` or
    `(features, labels, weights)` tuples.
  """
  name_to_features = {
      key: tf.io.FixedLenFeature([seq_length], tf.int64)
      for key in ('input_ids', 'input_mask', 'segment_ids')
  }
  name_to_features['label_ids'] = tf.io.FixedLenFeature([], label_type)
  if include_sample_weights:
    name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)

  dataset = single_file_dataset(
      file_path, name_to_features, num_samples=num_samples)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  context = input_pipeline_context
  if context and context.num_input_pipelines > 1:
    dataset = dataset.shard(context.num_input_pipelines,
                            context.input_pipeline_id)

  def _to_model_inputs(record):
    """Renames the parsed features and splits off labels (and weights)."""
    features = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    labels = record['label_ids']
    if not include_sample_weights:
      return (features, labels)
    return (features, labels, record['weight'])

  if is_training:
    dataset = dataset.shuffle(100).repeat()

  dataset = dataset.map(
      _to_model_inputs,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)


Hongkun Yu's avatar
Hongkun Yu committed
201
202
203
204
205
def create_squad_dataset(file_path,
                         seq_length,
                         batch_size,
                         is_training=True,
                         input_pipeline_context=None):
  """Builds the train/eval dataset for SQuAD-style data from (tf)records.

  Args:
    file_path: TFRecord path (or sequence of paths) given to
      `single_file_dataset`.
    seq_length: Length of the `input_ids`, `input_mask` and `segment_ids`
      features.
    batch_size: Number of examples per batch.
    is_training: When True the scalar `start_positions` / `end_positions`
      features are parsed and records are shuffled and repeated; otherwise
      the scalar `unique_ids` feature is parsed instead.
    input_pipeline_context: Optional object exposing `num_input_pipelines`
      and `input_pipeline_id`; when it reports more than one pipeline the
      records are sharded across pipelines.

  Returns:
    A batched and prefetched `tf.data.Dataset` of `(features, labels)`
    tuples of dicts. The final partial batch is always dropped.
  """
  label_names = ('start_positions', 'end_positions')
  # Parsed feature names that are exposed under a different key.
  renamed_inputs = {
      'input_ids': 'input_word_ids',
      'segment_ids': 'input_type_ids',
  }

  name_to_features = {
      key: tf.io.FixedLenFeature([seq_length], tf.int64)
      for key in ('input_ids', 'input_mask', 'segment_ids')
  }
  scalar_names = label_names if is_training else ('unique_ids',)
  for key in scalar_names:
    name_to_features[key] = tf.io.FixedLenFeature([], tf.int64)

  dataset = single_file_dataset(file_path, name_to_features)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  context = input_pipeline_context
  if context and context.num_input_pipelines > 1:
    dataset = dataset.shard(context.num_input_pipelines,
                            context.input_pipeline_id)

  def _split_features_and_labels(record):
    """Dispatches each parsed tensor to the feature dict or the label dict."""
    features, labels = {}, {}
    for name, tensor in record.items():
      if name in label_names:
        labels[name] = tensor
      else:
        # Names without a rename entry are passed through unchanged.
        features[renamed_inputs.get(name, name)] = tensor
    return (features, labels)

  if is_training:
    dataset = dataset.shuffle(100).repeat()

  dataset = dataset.map(
      _split_features_and_labels,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
250
251
252
253
254
255
256
257
258
259
260


def create_retrieval_dataset(file_path,
                             seq_length,
                             batch_size,
                             input_pipeline_context=None):
  """Builds the scoring dataset for retrieval from (tf)records.

  Args:
    file_path: TFRecord path (or sequence of paths) given to
      `single_file_dataset`.
    seq_length: Length of the `input_ids`, `input_mask` and `segment_ids`
      features.
    batch_size: Number of examples per batch; every emitted batch is padded
      up to this size.
    input_pipeline_context: Optional object exposing `num_input_pipelines`
      and `input_pipeline_id`; when it reports more than one pipeline the
      records are sharded across pipelines.

  Returns:
    A prefetched `tf.data.Dataset` of `(features, example_ids)` batches. No
    example is dropped: a short batch is extended with all-zero feature rows
    and `-1` example ids.
  """
  name_to_features = {
      key: tf.io.FixedLenFeature([seq_length], tf.int64)
      for key in ('input_ids', 'input_mask', 'segment_ids')
  }
  name_to_features['example_id'] = tf.io.FixedLenFeature([1], tf.int64)

  dataset = single_file_dataset(file_path, name_to_features)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  context = input_pipeline_context
  if context and context.num_input_pipelines > 1:
    dataset = dataset.shard(context.num_input_pipelines,
                            context.input_pipeline_id)

  def _to_model_inputs(record):
    """Renames the parsed features and splits off the example id."""
    features = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    return (features, record['example_id'])

  dataset = dataset.map(
      _to_model_inputs,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=False)

  def _fill_partial_batch(features, example_ids):
    """Pads a batch along axis 0 until it holds `batch_size` rows."""
    num_missing = batch_size - tf.shape(example_ids)[0]

    # All three id features share the [rows, seq_length] int32 layout, so one
    # block of zero rows is appended to each of them.
    zero_rows = tf.zeros(shape=[num_missing, seq_length], dtype=tf.int32)
    for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
      features[key] = tf.concat([features[key], zero_rows], axis=0)

    # Padding rows are given an example id of -1.
    filler_ids = -tf.ones(shape=[num_missing, 1], dtype=tf.int32)
    return features, tf.concat([example_ids, filler_ids], axis=0)

  dataset = dataset.map(
      _fill_partial_batch,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return dataset.prefetch(tf.data.experimental.AUTOTUNE)