run_squad.py 5.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
16

17
18
19
20
21
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
22
23
import os
import tempfile
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
24
25
import time

26
27
from absl import app
from absl import flags
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
28
from absl import logging
29
30
import tensorflow as tf

31
from official.nlp.bert import configs as bert_configs
Chen Chen's avatar
Chen Chen committed
32
from official.nlp.bert import run_squad_helper
33
from official.nlp.bert import tokenization
34
from official.nlp.data import squad_lib as squad_lib_wp
35
from official.utils.misc import distribution_utils
36
from official.utils.misc import keras_utils
37

Chen Chen's avatar
Chen Chen committed
38

39
40
41
# Script-specific flag; the remaining SQuAD flags are registered by the helper
# call below.
flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')

# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()

# Module-wide handle to the parsed flag values.
FLAGS = flags.FLAGS


48
49
50
51
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Runs BERT SQuAD training.

  Loads the BERT config from `FLAGS.bert_config_file` and delegates the actual
  training to `run_squad_helper.train_squad`.

  Args:
    strategy: distribution strategy object forwarded to the helper.
    input_meta_data: dictionary containing meta data about input and model.
    custom_callbacks: optional list of callbacks forwarded to the helper;
      `None` when no extra callbacks are requested.
    run_eagerly: flag forwarded to the helper unchanged.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
                               custom_callbacks, run_eagerly)
56
57
58


def predict_squad(strategy, input_meta_data):
  """Makes predictions for the squad dataset.

  Builds the BERT config from `FLAGS.bert_config_file` and a `FullTokenizer`
  from `FLAGS.vocab_file` / `FLAGS.do_lower_case`, then delegates to
  `run_squad_helper.predict_squad` with the `squad_lib_wp` data library.

  Args:
    strategy: distribution strategy object forwarded to the helper.
    input_meta_data: dictionary containing meta data about input and model.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  run_squad_helper.predict_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)


def eval_squad(strategy, input_meta_data):
  """Evaluates the model on the squad dataset and returns the metrics.

  Args:
    strategy: distribution strategy object forwarded to the helper.
    input_meta_data: dictionary containing meta data about input and model.

  Returns:
    Whatever `run_squad_helper.eval_squad` returns for this configuration.
  """
  config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  full_tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  return run_squad_helper.eval_squad(
      strategy, input_meta_data, full_tokenizer, config, squad_lib_wp)
75
76


Hongkun Yu's avatar
Hongkun Yu committed
77
78
79
80
81
82
83
84
85
86
def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Loads the BERT config from `FLAGS.bert_config_file` and delegates the export
  to `run_squad_helper.export_squad`.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  NOTE(review): an error is expected when the export path is not specified
  (empty string or None). That check is not performed in this function;
  presumably it lives in `run_squad_helper.export_squad` -- confirm the
  exception type there.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
Hongkun Yu's avatar
Hongkun Yu committed
89
90


91
92
def main(_):
  """Runs the SQuAD phases selected by `--mode`.

  Reads the input meta data JSON from `FLAGS.input_meta_data_path`. When the
  mode is exactly 'export_only' the model is exported and the function returns.
  Otherwise a distribution strategy is built and training, prediction and
  evaluation each run when their keyword ('train', 'predict', 'eval') appears
  in `FLAGS.mode`.

  Args:
    _: unused positional arguments passed in by `app.run`.
  """
  # Users should always run this script under TF 2.x

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  # Export needs no distribution strategy, so handle it before building one.
  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    # Only attach the timing callback when a logging interval is configured.
    if FLAGS.log_steps:
      custom_callbacks = [keras_utils.TimeHistory(
          batch_size=FLAGS.train_batch_size,
          log_steps=FLAGS.log_steps,
          logdir=FLAGS.model_dir,
      )]
    else:
      custom_callbacks = None

    train_squad(
        strategy,
        input_meta_data,
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
    )
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    # Write the summary under the model directory only when there is no
    # strategy or the strategy reports it should save summaries; otherwise
    # the summary goes to a throwaway temporary directory.
    if (not strategy) or strategy.extended.should_save_summary:
      summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
    else:
      summary_dir = tempfile.mkdtemp()
    summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, 'eval'))
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Wait for some time, for the depending mldash/tensorboard jobs to finish
    # exporting the final F1-score.
    time.sleep(60)
146
147
148
149
150
151


if __name__ == '__main__':
  # Both flags must be supplied on the command line before main() runs.
  for required_flag in ('bert_config_file', 'model_dir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)