# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in tf2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from official.modeling import model_training_utils
from official.nlp import bert_modeling as modeling
from official.nlp import bert_models
from official.nlp import optimization
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
# word-piece tokenizer based squad_lib
from official.nlp.bert import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.bert import squad_lib_sp
from official.nlp.bert import tokenization
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib

flags.DEFINE_enum(
    'mode', 'train_and_predict',
    ['train_and_predict', 'train', 'predict', 'export_only'],
    'One of {"train_and_predict", "train", "predict", "export_only"}. '
    '`train_and_predict`: both train and predict to a json file. '
    '`train`: only trains the model. '
    '`predict`: predict answers from the squad json file. '
    '`export_only`: will take the latest checkpoint inside '
    'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
                    'Training data path with train tfrecords.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to file that contains meta data about input '
    'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string(
    'predict_file', None,
    'SQuAD json file path for predictions (e.g., dev-v1.1.json).')
flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
    'do_lower_case', True,
    'Whether to lower case the input text. Should be True for uncased '
    'models and False for cased models.')
flags.DEFINE_float(
    'null_score_diff_threshold', 0.0,
    'If null_score - best_non_null is greater than the threshold, '
    'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
    'verbose_logging', False,
    'If true, all of the warnings related to data processing will be printed. '
    'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
                     'Total batch size for prediction.')
flags.DEFINE_integer(
    'n_best_size', 20,
    'The total number of n-best predictions to generate in the '
    'nbest_predictions.json output file.')
flags.DEFINE_integer(
    'max_answer_length', 30,
    'The maximum length of an answer that can be generated. This is needed '
    'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_string(
    'sp_model_file', None,
    'The path to the sentence piece model. Used by sentence piece tokenizer '
    'employed by ALBERT.')


common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

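# Maps FLAGS.model_type to the (config class, squad_lib module, tokenizer
# class) triple used throughout this script.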
MODEL_CLASSES = {
    'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
    'albert': (modeling.AlbertConfig, squad_lib_sp,
               tokenization.FullSentencePieceTokenizer),
}


def squad_loss_fn(start_positions,
                  end_positions,
                  start_logits,
                  end_logits,
                  loss_factor=1.0):
  """Returns sparse categorical crossentropy for start/end logits."""
  start_loss = tf.keras.backend.sparse_categorical_crossentropy(
      start_positions, start_logits, from_logits=True)
  end_loss = tf.keras.backend.sparse_categorical_crossentropy(
      end_positions, end_logits, from_logits=True)

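  # Average the start- and end-position losses so the total stays on the same
  # scale as a single cross-entropy term, then apply the optional loss factor.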
  total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
  total_loss *= loss_factor
  return total_loss


def get_loss_fn(loss_factor=1.0):
  """Gets a loss function for squad task."""

  def _loss_fn(labels, model_outputs):
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs
    return squad_loss_fn(
        start_positions,
        end_positions,
        start_logits,
        end_logits,
        loss_factor=loss_factor)

  return _loss_fn


def get_raw_results(predictions):
  """Converts multi-replica predictions to RawResult."""
  squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
  for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
                                                  predictions['start_logits'],
                                                  predictions['end_logits']):
    for values in zip(unique_ids.numpy(), start_logits.numpy(),
                      end_logits.numpy()):
      yield squad_lib.RawResult(
          unique_id=values[0],
          start_logits=values[1].tolist(),
          end_logits=values[2].tolist())


def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training):
  """Gets a closure to create a dataset.."""

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_squad_dataset(
        input_file_pattern,
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx)
    return dataset

  return _dataset_fn


def predict_squad_customized(strategy, input_meta_data, bert_config,
                             predict_tfrecord_path, num_steps):
  """Make predictions using a Bert-based squad model."""
  predict_dataset_fn = get_dataset_fn(
      predict_tfrecord_path,
      input_meta_data['max_seq_length'],
      FLAGS.predict_batch_size,
      is_training=False)
  predict_iterator = iter(
      strategy.experimental_distribute_datasets_from_function(
          predict_dataset_fn))

  with strategy.scope():
    # Prediction always uses float32, even if training uses mixed precision.
    tf.keras.mixed_precision.experimental.set_policy('float32')
    squad_model, _ = bert_models.squad_model(
        bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)

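  # Restore the latest training checkpoint. `expect_partial()` silences
  # warnings about objects saved during training (e.g. the optimizer) that
  # this inference-only model does not use.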
  checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
  logging.info('Restoring checkpoints from %s', checkpoint_path)
  checkpoint = tf.train.Checkpoint(model=squad_model)
  checkpoint.restore(checkpoint_path).expect_partial()

  @tf.function
  def predict_step(iterator):
    """Predicts on distributed devices."""

    def _replicated_step(inputs):
      """Replicated prediction calculation."""
      x, _ = inputs
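      # `unique_ids` only joins predictions back to their examples; it is not
      # a model input, so pop it before the forward pass.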
      unique_ids = x.pop('unique_ids')
      start_logits, end_logits = squad_model(x, training=False)
      return dict(
          unique_ids=unique_ids,
          start_logits=start_logits,
          end_logits=end_logits)

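    # Run one step on each replica, then gather every replica's outputs back
    # to the host as per-replica value tuples.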
    outputs = strategy.experimental_run_v2(
        _replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  all_results = []
  for _ in range(num_steps):
    predictions = predict_step(predict_iterator)
    for result in get_raw_results(predictions):
      all_results.append(result)
    if len(all_results) % 100 == 0:
      logging.info('Made predictions for %d records.', len(all_results))
  return all_results


def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  use_float16 = common_flags.use_float16()
  if use_float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
      FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
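  # Linear learning-rate warmup over the first 10% of all training steps.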
  warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        float_type=tf.float16 if use_float16 else tf.float32,
        hub_module_url=FLAGS.hub_module_url)
    squad_model.optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
    if use_float16:
      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
      # in compile() with the "mixed_float16" policy, but since we do not call
      # compile(), we must wrap the optimizer manually.
      squad_model.optimizer = (
          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
    if FLAGS.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          squad_model.optimizer)
    return squad_model, core_model

  # The original BERT model does not scale the loss by 1/num_replicas_in_sync,
  # which may have been unintentional. To reuse the same hyperparameters, we
  # likewise keep each replica loss unscaled unless `scale_loss` is set.
  loss_fn = get_loss_fn(
      loss_factor=1.0 /
      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)

  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=loss_fn,
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks)


def predict_squad(strategy, input_meta_data):
  """Makes predictions for a squad dataset."""
  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
  bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
  if tokenizer_cls == tokenization.FullTokenizer:
    tokenizer = tokenizer_cls(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  else:
    assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
    tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
  doc_stride = input_meta_data['doc_stride']
  max_query_length = input_meta_data['max_query_length']
  # Whether data should be in Ver 2.0 format.
  version_2_with_negative = input_meta_data.get('version_2_with_negative',
                                                False)
  eval_examples = squad_lib.read_squad_examples(
      input_file=FLAGS.predict_file,
      is_training=False,
      version_2_with_negative=version_2_with_negative)

  eval_writer = squad_lib.FeatureWriter(
      filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
      is_training=False)
  eval_features = []

  def _append_feature(feature, is_padding):
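    # Padding features are written to the eval tf_record so every batch stays
    # full, but only real features are kept for post-processing.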
    if not is_padding:
      eval_features.append(feature)
    eval_writer.process_feature(feature)

  # TPU requires a fixed batch size for all batches, therefore the number
  # of examples must be a multiple of the batch size, or else examples
  # will get dropped. So we pad with fake examples which are ignored
  # later on.
  kwargs = dict(
      examples=eval_examples,
      tokenizer=tokenizer,
      max_seq_length=input_meta_data['max_seq_length'],
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=False,
      output_fn=_append_feature,
      batch_size=FLAGS.predict_batch_size)

  # squad_lib_sp requires one more argument 'do_lower_case'.
  if squad_lib == squad_lib_sp:
    kwargs['do_lower_case'] = FLAGS.do_lower_case
  dataset_size = squad_lib.convert_examples_to_features(**kwargs)
  eval_writer.close()

  logging.info('***** Running predictions *****')
  logging.info('  Num orig examples = %d', len(eval_examples))
  logging.info('  Num split examples = %d', len(eval_features))
  logging.info('  Batch size = %d', FLAGS.predict_batch_size)

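  # `dataset_size` was padded above to a multiple of the batch size, so this
  # division is exact and no real examples are dropped.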
  num_steps = int(dataset_size / FLAGS.predict_batch_size)
  all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
                                         eval_writer.filename, num_steps)

  output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
  output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
  output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')

  squad_lib.write_predictions(
      eval_examples,
      eval_features,
      all_results,
      FLAGS.n_best_size,
      FLAGS.max_answer_length,
      FLAGS.do_lower_case,
      output_prediction_file,
      output_nbest_file,
      output_null_log_odds_file,
      version_2_with_negative=version_2_with_negative,
      null_score_diff_threshold=FLAGS.null_score_diff_threshold,
      verbose=FLAGS.verbose_logging)


def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    ValueError: If `model_export_path` is not specified (empty or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
      FLAGS.bert_config_file)
  squad_model, _ = bert_models.squad_model(
      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
  model_saving_utils.export_bert_model(
      model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)


def main(_):
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  if FLAGS.mode in ('train', 'train_and_predict'):
    train_squad(strategy, input_meta_data)
  if FLAGS.mode in ('predict', 'train_and_predict'):
    predict_squad(strategy, input_meta_data)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('model_dir')
  app.run(main)