# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
16

17
import json
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
18
import os
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
19
20
import time

Hongkun Yu's avatar
Hongkun Yu committed
21
# Import libraries
22
23
from absl import app
from absl import flags
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
24
from absl import logging
Le Hou's avatar
Le Hou committed
25
import gin
26
import tensorflow as tf
27
from official.common import distribute_utils
28
from official.nlp.bert import configs as bert_configs
Chen Chen's avatar
Chen Chen committed
29
from official.nlp.bert import run_squad_helper
30
from official.nlp.bert import tokenization
31
from official.nlp.data import squad_lib as squad_lib_wp
32
from official.utils.misc import keras_utils
33

Chen Chen's avatar
Chen Chen committed
34

35
36
37
# Flag owned by this binary; it is read when building the tokenizer for the
# predict and eval phases.
flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')

# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()

FLAGS = flags.FLAGS


44
45
46
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
  """Run bert squad training.

  Loads the BERT config named by `FLAGS.bert_config_file` and delegates the
  actual training to `run_squad_helper.train_squad`.

  Args:
    strategy: distribution strategy object, forwarded to the helper.
    input_meta_data: meta data about the input, forwarded to the helper.
    custom_callbacks: optional callbacks, forwarded to the helper.
    run_eagerly: forwarded to the helper unchanged.
    init_checkpoint: checkpoint to initialize from. When falsy (None or an
      empty string), `FLAGS.init_checkpoint` is used instead.
    sub_model_export_name: forwarded to the helper by keyword.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  # An explicit argument wins; otherwise fall back to the command-line flag.
  init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
                               custom_callbacks, run_eagerly, init_checkpoint,
                               sub_model_export_name=sub_model_export_name)
56
57
58


def predict_squad(strategy, input_meta_data):
  """Makes predictions for the squad dataset.

  Builds the BERT config from `FLAGS.bert_config_file` and a `FullTokenizer`
  from `FLAGS.vocab_file` / `FLAGS.do_lower_case`, then delegates to
  `run_squad_helper.predict_squad`.

  Args:
    strategy: distribution strategy object, forwarded to the helper.
    input_meta_data: meta data about the input, forwarded to the helper.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  # squad_lib_wp is handed to the helper as the SQuAD data library to use.
  run_squad_helper.predict_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)


def eval_squad(strategy, input_meta_data):
  """Evaluates on the squad dataset and returns the helper's metrics.

  Args:
    strategy: distribution strategy object, forwarded to the helper.
    input_meta_data: meta data about the input, forwarded to the helper.

  Returns:
    Whatever `run_squad_helper.eval_squad` returns for this configuration.
  """
  config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  wordpiece_tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
  )
  return run_squad_helper.eval_squad(
      strategy, input_meta_data, wordpiece_tokenizer, config, squad_lib_wp)
75
76


Hongkun Yu's avatar
Hongkun Yu committed
77
78
79
80
81
82
83
84
85
86
def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Loads the BERT config named by `FLAGS.bert_config_file` and delegates the
  export to `run_squad_helper.export_squad`.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    An error from `run_squad_helper.export_squad` when the export path is not
    specified (an empty string or None). This function performs no validation
    of its own; the exact exception type is defined by the helper.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
Hongkun Yu's avatar
Hongkun Yu committed
89
90


91
def main(_):
  """Runs the SQuAD phases selected by `FLAGS.mode`.

  Parses gin configuration, reads the input meta data JSON from
  `FLAGS.input_meta_data_path`, and then either exports the model
  (`mode == 'export_only'`) or runs training, prediction and evaluation for
  each of the substrings 'train', 'predict' and 'eval' found in `FLAGS.mode`.

  Args:
    _: unused positional argument supplied by `app.run`.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  # Export needs no distribution strategy, so handle it before building one.
  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  # The mode is matched by substring, so a combined mode string can trigger
  # several of the phases below; they always run in train/predict/eval order.
  if 'train' in FLAGS.mode:
    if FLAGS.log_steps:
      custom_callbacks = [keras_utils.TimeHistory(
          batch_size=FLAGS.train_batch_size,
          log_steps=FLAGS.log_steps,
          logdir=FLAGS.model_dir,
      )]
    else:
      custom_callbacks = None

    train_squad(
        strategy,
        input_meta_data,
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
        sub_model_export_name=FLAGS.sub_model_export_name,
    )
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also write eval_metrics to json file.
    squad_lib_wp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
    # NOTE(review): the reason for this fixed one-minute delay is not visible
    # in this file; presumably it gives pending writes time to complete before
    # the process exits. Confirm before changing or removing it.
    time.sleep(60)
143
144
145
146
147
148


if __name__ == '__main__':
  # These flags have no usable default, so fail fast if either is missing.
  for required_flag in ('bert_config_file', 'model_dir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)