# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""

import functools
import json
import os

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# squad_lib based on the WordPiece tokenizer.
from official.nlp.data import squad_lib as squad_lib_wp
# squad_lib based on the SentencePiece tokenizer.
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib

FLAGS = flags.FLAGS

# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum(
    "fine_tuning_task_type", "classification",
    ["classification", "regression", "squad", "retrieval", "tagging"],
    "The name of the BERT fine tuning task for which data "
41
    "will be generated.")

# BERT classification-specific flags.
flags.DEFINE_string(
    "input_data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum(
    "classification_task_name", "MNLI", [
        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g",
        "RTE-SuperGLUE"
    ], "The name of the task to train BERT classifier. The "
    "difference between XTREME-XNLI and XNLI is: 1. the format "
    "of input tsv files; 2. the dev set for XTREME is english "
    "only and for XNLI is all languages combined. Same for "
    "PAWS-X.")

# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                  "The type of MNLI dataset.")

# XNLI task-specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training data for XNLI task. If the value is 'all', the data "
    "of all languages will be used for training.")

# PAWS-X task-specific flag.
flags.DEFINE_string(
    "pawsx_language", "en",
    "Language of training data for PAWS-X task. If the value is 'all', the data "
    "of all languages will be used for training.")

# XTREME classification-specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
    "translated_input_data_dir", None,
    "The translated input data dir. Should contain the .tsv files (or other "
    "data files) for the task.")

# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of BERT tagging (token classification) task.")

flags.DEFINE_bool("tagging_only_use_en_train", True,
                  "Whether only use english training data in tagging.")

# BERT SQuAD task-specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file for generating training data for the BERT SQuAD "
    "task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for the BERT "
    "SQuAD task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_bool(
    "xlnet_format", False,
    "If true, then data will be preprocessed in a paragraph, query, class order"
    " instead of the BERT-style class, paragraph, query order.")

# XTREME-specific flags.
flags.DEFINE_bool("only_use_en_dev", True,
                  "Whether to use only English dev data.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " as test_{}.tfrecords if processor has language specific test data.")

flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

flags.DEFINE_enum(
    "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.")

flags.DEFINE_string(
    "tfds_params", "", "Comma-separated list of TFDS parameter assignments "
    "for generic classification data import (for more details "
    "see the TfdsProcessor class documentation).")


def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
          FLAGS.tfds_params)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "ax":
            classifier_data_lib.AxProcessor,
        "cola":
            classifier_data_lib.ColaProcessor,
        "imdb":
            classifier_data_lib.ImdbProcessor,
        "mnli":
            functools.partial(
                classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
        "mrpc":
            classifier_data_lib.MrpcProcessor,
        "qnli":
            classifier_data_lib.QnliProcessor,
        "qqp":
            classifier_data_lib.QqpProcessor,
        "rte":
            classifier_data_lib.RteProcessor,
        "sst-2":
            classifier_data_lib.SstProcessor,
        "sts-b":
            classifier_data_lib.StsBProcessor,
        "xnli":
            functools.partial(
                classifier_data_lib.XnliProcessor,
                language=FLAGS.xnli_language),
        "paws-x":
            functools.partial(
                classifier_data_lib.PawsxProcessor,
                language=FLAGS.pawsx_language),
        "wnli":
            classifier_data_lib.WnliProcessor,
        "xtreme-xnli":
            functools.partial(
                classifier_data_lib.XtremeXnliProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "xtreme-paws-x":
            functools.partial(
                classifier_data_lib.XtremePawsxProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "ax-g":
            classifier_data_lib.AXgProcessor,
        "rte-superglue":
            classifier_data_lib.RTESuperGLUEProcessor
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)


def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    raise ValueError("No data processor found for the given regression task.")


def generate_squad_dataset():
  """Generates squad training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
  if FLAGS.tokenization == "WordPiece":
    return squad_lib_wp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        vocab_file_path=FLAGS.vocab_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        version_2_with_negative=FLAGS.version_2_with_negative,
        xlnet_format=FLAGS.xlnet_format)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    return squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        sp_model_file=FLAGS.sp_model_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        xlnet_format=FLAGS.xlnet_format,
        version_2_with_negative=FLAGS.version_2_with_negative)


def generate_retrieval_dataset():
  """Generate retrieval test and dev dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processors = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }

  task_name = FLAGS.retrieval_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx":
          functools.partial(
              tagging_data_lib.PanxProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
      "udpos":
          functools.partial(
              tagging_data_lib.UdposProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif FLAGS.tokenization == "SentencePiece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, processor_text_fn)


def main(_):
  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "Flag vocab_file for the WordPiece tokenizer is not specified.")
  else:
    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "Flag sp_model_file for the SentencePiece tokenizer is not "
          "specified.")

  if FLAGS.fine_tuning_task_type != "retrieval":
    flags.mark_flag_as_required("train_data_output_path")

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  else:
    assert FLAGS.fine_tuning_task_type == "tagging"
    input_meta_data = generate_tagging_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
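  # Each generator returns a plain dict of input meta data (for example,
  # max_seq_length and data sizes; the exact keys vary by task), which is
  # written below as indented JSON.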
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
  flags.mark_flag_as_required("meta_data_file_path")
  app.run(main)