# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in tf2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import json
import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

# Import BERT model libraries.
from official.bert import bert_models
from official.bert import common_flags
from official.bert import input_pipeline
from official.bert import model_training_utils
from official.bert import modeling
from official.bert import optimization
from official.bert import squad_lib
from official.bert import tokenization
from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib

# Which phases of the script to run; both may be enabled in one invocation.
flags.DEFINE_bool('do_train', False, 'Whether to run training.')
flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
flags.DEFINE_string('train_data_path', '',
                    'Training data path with train tfrecords.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
# NOTE: this file is handed to squad_lib.read_squad_examples as raw examples;
# the prediction tfrecords are generated from it at prediction time.
flags.DEFINE_string('predict_file', None,
                    'Path to the SQuAD examples file to run prediction on.')
flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
    'do_lower_case', True,
    'Whether to lower case the input text. Should be True for uncased '
    'models and False for cased models.')
flags.DEFINE_bool(
    'verbose_logging', False,
    'If true, all of the warnings related to data processing will be printed. '
    'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
                     'Total batch size for prediction.')
flags.DEFINE_integer(
    'n_best_size', 20,
    'The total number of n-best predictions to generate in the '
    'nbest_predictions.json output file.')
flags.DEFINE_integer(
    'max_answer_length', 30,
    'The maximum length of an answer that can be generated. This is needed '
    'because the start and end predictions are not conditioned on one another.')

# Registers the flags shared across the BERT scripts (defined in common_flags).
common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS


def squad_loss_fn(start_positions,
                  end_positions,
                  start_logits,
                  end_logits,
                  loss_factor=1.0):
  """Returns sparse categorical crossentropy for start/end logits.

  Args:
    start_positions: Target labels scored against `start_logits`.
    end_positions: Target labels scored against `end_logits`.
    start_logits: Unnormalized predictions for the answer start (the loss is
      computed with `from_logits=True`).
    end_logits: Unnormalized predictions for the answer end (the loss is
      computed with `from_logits=True`).
    loss_factor: Multiplier applied to the combined loss before returning.

  Returns:
    The average of the mean start loss and the mean end loss, multiplied by
    `loss_factor`.
  """
  start_loss = tf.keras.backend.sparse_categorical_crossentropy(
      start_positions, start_logits, from_logits=True)
  end_loss = tf.keras.backend.sparse_categorical_crossentropy(
      end_positions, end_logits, from_logits=True)

  # Start and end contribute equally: average the two mean losses.
  total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
  total_loss *= loss_factor
  return total_loss


def get_loss_fn(loss_factor=1.0):
  """Gets a loss function for squad task.

  Args:
    loss_factor: Multiplier forwarded to `squad_loss_fn` on every call of the
      returned function.

  Returns:
    A callable `(labels, model_outputs) -> loss`. `labels` must provide the
    keys 'start_positions' and 'end_positions'; `model_outputs` must unpack
    into three items, of which the last two are the start and end logits.
  """

  def _loss_fn(labels, model_outputs):
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    # The first model output is not used by the loss.
    _, start_logits, end_logits = model_outputs
    return squad_loss_fn(
        start_positions,
        end_positions,
        start_logits,
        end_logits,
        loss_factor=loss_factor)

  return _loss_fn


def get_raw_results(predictions):
  """Yields one `squad_lib.RawResult` per record found in `predictions`.

  Args:
    predictions: Mapping with the keys 'unique_ids', 'start_logits' and
      'end_logits'. Each value is an iterable of tensors (objects exposing
      `.numpy()`), presumably one entry per replica -- confirm with caller.

  Yields:
    A `squad_lib.RawResult` whose logits have been converted to Python lists.
  """
  grouped = zip(predictions['unique_ids'],
                predictions['start_logits'],
                predictions['end_logits'])
  for ids_tensor, start_tensor, end_tensor in grouped:
    # Walk the three arrays in lockstep, one row per record.
    records = zip(ids_tensor.numpy(), start_tensor.numpy(), end_tensor.numpy())
    for record_id, record_start, record_end in records:
      yield squad_lib.RawResult(
          unique_id=record_id,
          start_logits=record_start.tolist(),
          end_logits=record_end.tolist())


def predict_squad_customized(strategy, input_meta_data, bert_config,
                             predict_tfrecord_path, num_steps):
  """Make predictions using a Bert-based squad model.

  Args:
    strategy: Distribution strategy used to distribute the dataset and to run
      the replicated prediction step.
    input_meta_data: Mapping that provides 'max_seq_length'.
    bert_config: BERT configuration passed to `bert_models.squad_model`.
    predict_tfrecord_path: Path of the tfrecord file holding the features to
      predict on.
    num_steps: Number of prediction steps (batches) to run.

  Returns:
    A list of `squad_lib.RawResult`, as produced by `get_raw_results`.
  """
  # An empty device string leaves placement to TensorFlow when no TPU is set.
  primary_cpu_task = '/job:worker' if FLAGS.tpu else ''

  with tf.device(primary_cpu_task):
    predict_dataset = input_pipeline.create_squad_dataset(
        predict_tfrecord_path,
        input_meta_data['max_seq_length'],
        FLAGS.predict_batch_size,
        is_training=False)
    predict_iterator = iter(
        strategy.experimental_distribute_dataset(predict_dataset))

    with strategy.scope():
      # Prediction always uses float32, even if training uses mixed precision.
      tf.keras.mixed_precision.experimental.set_policy('float32')
      squad_model, _ = bert_models.squad_model(
          bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=squad_model)
    # Only the model is restored, so unmatched checkpoint values are expected.
    checkpoint.restore(checkpoint_path).expect_partial()

    @tf.function
    def predict_step(iterator):
      """Predicts on distributed devices."""

      def _replicated_step(inputs):
        """Replicated prediction calculation."""
        x, _ = inputs
        unique_ids, start_logits, end_logits = squad_model(x, training=False)
        return dict(
            unique_ids=unique_ids,
            start_logits=start_logits,
            end_logits=end_logits)

      outputs = strategy.experimental_run_v2(
          _replicated_step, args=(next(iterator),))
      return tf.nest.map_structure(strategy.experimental_local_results, outputs)

    all_results = []
    for _ in range(num_steps):
      predictions = predict_step(predict_iterator)
      all_results.extend(get_raw_results(predictions))
      # Progress is only reported when the running total is a multiple of 100.
      if len(all_results) % 100 == 0:
        logging.info('Made predictions for %d records.', len(all_results))
    return all_results


def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training.

  Args:
    strategy: Distribution strategy handed to the customized training loop.
    input_meta_data: Mapping that provides 'train_data_size' and
      'max_seq_length'.
    custom_callbacks: Optional callbacks forwarded unchanged to
      `model_training_utils.run_customized_training_loop`.
    run_eagerly: Forwarded unchanged to
      `model_training_utils.run_customized_training_loop`.
  """
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  # Warm up over the first 10% of all training steps.
  warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = functools.partial(
      input_pipeline.create_squad_dataset,
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    return squad_model, core_model

  # The original BERT model does not scale the loss by
  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
  # the same hyper parameter, we do the same thing here by keeping each
  # replica loss as it is.
  loss_fn = get_loss_fn(loss_factor=1.0)
  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)


def predict_squad(strategy, input_meta_data):
  """Makes predictions for a squad dataset.

  Reads the examples from `FLAGS.predict_file`, converts them to features
  that are written to 'eval.tf_record' under `FLAGS.model_dir`, runs the model
  over that file and writes 'predictions.json', 'nbest_predictions.json' and
  'null_odds.json' into `FLAGS.model_dir`.

  Args:
    strategy: Distribution strategy forwarded to `predict_squad_customized`.
    input_meta_data: Mapping that provides 'doc_stride', 'max_query_length',
      'max_seq_length' and, optionally, 'version_2_with_negative'.
  """
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  doc_stride = input_meta_data['doc_stride']
  max_query_length = input_meta_data['max_query_length']
  # Whether data should be in Ver 2.0 format.
  version_2_with_negative = input_meta_data.get('version_2_with_negative',
                                                False)
  eval_examples = squad_lib.read_squad_examples(
      input_file=FLAGS.predict_file,
      is_training=False,
      version_2_with_negative=version_2_with_negative)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  eval_writer = squad_lib.FeatureWriter(
      filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
      is_training=False)
  eval_features = []

  def _append_feature(feature, is_padding):
    # Padding features are written to the record file but kept out of the
    # list used for post-processing.
    if not is_padding:
      eval_features.append(feature)
    eval_writer.process_feature(feature)

  # TPU requires a fixed batch size for all batches, therefore the number
  # of examples must be a multiple of the batch size, or else examples
  # will get dropped. So we pad with fake examples which are ignored
  # later on.
  try:
    dataset_size = squad_lib.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=input_meta_data['max_seq_length'],
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=False,
        output_fn=_append_feature,
        batch_size=FLAGS.predict_batch_size)
  finally:
    # Close the writer even when conversion fails so it is not left open.
    eval_writer.close()

  logging.info('***** Running predictions *****')
  logging.info('  Num orig examples = %d', len(eval_examples))
  logging.info('  Num split examples = %d', len(eval_features))
  logging.info('  Batch size = %d', FLAGS.predict_batch_size)

  num_steps = int(dataset_size / FLAGS.predict_batch_size)
  all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
                                         eval_writer.filename, num_steps)

  output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
  output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
  output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')

  squad_lib.write_predictions(
      eval_examples,
      eval_features,
      all_results,
      FLAGS.n_best_size,
      FLAGS.max_answer_length,
      FLAGS.do_lower_case,
      output_prediction_file,
      output_nbest_file,
      output_null_log_odds_file,
      verbose=FLAGS.verbose_logging)


def main(_):
  """Loads the input meta data, builds the strategy and runs train/predict.

  Training runs when `FLAGS.do_train` is set and prediction when
  `FLAGS.do_predict` is set; with both set, training runs first.

  Args:
    _: Unused positional argument supplied by `app.run`.

  Raises:
    ValueError: If `FLAGS.strategy_type` is not one of 'mirror',
      'multi_worker_mirror' or 'tpu'.
  """
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'multi_worker_mirror':
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)
  if FLAGS.do_train:
    train_squad(strategy, input_meta_data)
  if FLAGS.do_predict:
    predict_squad(strategy, input_meta_data)


if __name__ == '__main__':
  # Both flags are read unconditionally by the train and predict paths, so
  # fail fast at startup when either one is missing.
  for _required_flag in ('bert_config_file', 'model_dir'):
    flags.mark_flag_as_required(_required_flag)
  app.run(main)