# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT model input pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.io.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.cast(t, tf.int32)
    example[name] = t

  return example


def single_file_dataset(input_file, name_to_features):
  """Creates a single-file dataset to be passed for BERT custom training."""
  # For training, we want a lot of parallel reading and shuffling.
  # For eval, we want no shuffling and parallel reading doesn't matter.
  d = tf.data.TFRecordDataset(input_file)
  d = d.map(lambda record: decode_record(record, name_to_features))

  # When `input_file` is a path to a single file or a list
  # containing a single path, disable auto sharding so that
  # same input file is sent to all workers.
  if isinstance(input_file, str) or len(input_file) == 1:
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    d = d.with_options(options)
  return d
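
# Illustrative usage of `single_file_dataset` (a sketch only; the feature spec
# and file path below are hypothetical, not part of this module):
#
#   name_to_features = {
#       'input_ids': tf.io.FixedLenFeature([128], tf.int64),
#       'label_ids': tf.io.FixedLenFeature([], tf.int64),
#   }
#   eval_dataset = single_file_dataset('/tmp/eval.tf_record', name_to_features)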


def create_pretrain_dataset(input_patterns,
                            seq_length,
                            max_predictions_per_seq,
                            batch_size,
                            is_training=True,
                            input_pipeline_context=None,
                            use_next_sentence_label=True,
                            use_position_id=False):
  """Creates input dataset from (tf)records files for pretraining."""
  name_to_features = {
      'input_ids':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids':
          tf.io.FixedLenFeature([seq_length], tf.int64),
      'masked_lm_positions':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_ids':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
      'masked_lm_weights':
          tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
  }
  if use_next_sentence_label:
    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
                                                                     tf.int64)
  if use_position_id:
    name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
                                                             tf.int64)
  for input_pattern in input_patterns:
    if not tf.io.gfile.glob(input_pattern):
      raise ValueError('%s does not match any files.' % input_pattern)

  dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)

  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)
  if is_training:
    dataset = dataset.repeat()

    # We set shuffle buffer to exactly match total number of
    # training files to ensure that training data is well shuffled.
    input_files = []
    for input_pattern in input_patterns:
      input_files.extend(tf.io.gfile.glob(input_pattern))
    dataset = dataset.shuffle(len(input_files))

  # In parallel, create a TFRecord dataset for each training file.
  # cycle_length = 8 means that up to 8 files will be read and deserialized in
  # parallel. You may want to increase this number if you have a large number of
  # CPU cores.
  dataset = dataset.interleave(
      tf.data.TFRecordDataset, cycle_length=8,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  decode_fn = lambda record: decode_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Filter out features to use for pretraining."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids'],
        'masked_lm_positions': record['masked_lm_positions'],
        'masked_lm_ids': record['masked_lm_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    if use_position_id:
      x['position_ids'] = record['position_ids']

    y = record['masked_lm_weights']

    return (x, y)

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if is_training:
    dataset = dataset.shuffle(100)

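  # Drop the last partial batch only during training so every batch has the
  # full `batch_size`; for eval, keep the remainder so no examples are lost.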
  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  dataset = dataset.prefetch(1024)
  return dataset
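
# Illustrative usage of `create_pretrain_dataset` (a sketch only; the file
# pattern and hyperparameters below are hypothetical):
#
#   train_dataset = create_pretrain_dataset(
#       input_patterns=['/tmp/pretrain/*.tf_record'],
#       seq_length=128,
#       max_predictions_per_seq=20,
#       batch_size=32,
#       is_training=True)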


def create_classifier_dataset(file_path,
                              seq_length,
                              batch_size,
                              is_training=True,
                              input_pipeline_context=None):
  """Creates input dataset from (tf)records files for train/eval."""
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'label_ids': tf.io.FixedLenFeature([], tf.int64),
      'is_real_example': tf.io.FixedLenFeature([], tf.int64),
  }
  dataset = single_file_dataset(file_path, name_to_features)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  def _select_data_from_record(record):
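    """Filters and renames raw features into a (features, label_ids) pair."""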
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    y = record['label_ids']
    return (x, y)

  dataset = dataset.map(_select_data_from_record)

  if is_training:
    dataset = dataset.shuffle(100)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  dataset = dataset.prefetch(1024)
  return dataset
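
# Illustrative usage of `create_classifier_dataset` (a sketch only; the file
# path, sequence length, and batch size below are hypothetical):
#
#   train_dataset = create_classifier_dataset(
#       '/tmp/glue_train.tf_record', seq_length=128, batch_size=32,
#       is_training=True)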


def create_squad_dataset(file_path,
                         seq_length,
                         batch_size,
                         is_training=True,
                         input_pipeline_context=None):
  """Creates input dataset from (tf)records files for train/eval."""
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
  }
  if is_training:
    name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
  else:
    name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)

  dataset = single_file_dataset(file_path, name_to_features)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  def _select_data_from_record(record):
    """Dispatches record to features and labels."""
    x, y = {}, {}
    for name, tensor in record.items():
      if name in ('start_positions', 'end_positions'):
        y[name] = tensor
      elif name == 'input_ids':
        x['input_word_ids'] = tensor
      elif name == 'segment_ids':
        x['input_type_ids'] = tensor
      else:
        x[name] = tensor
    return (x, y)

  dataset = dataset.map(_select_data_from_record)

  if is_training:
    dataset = dataset.shuffle(100)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(1024)
  return dataset
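
# Illustrative usage of `create_squad_dataset` (a sketch only; the file path
# and parameters below are hypothetical):
#
#   train_dataset = create_squad_dataset(
#       '/tmp/squad_train.tf_record', seq_length=384, batch_size=8,
#       is_training=True)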