Commit e0d33132 authored by Tianqi Liu's avatar Tianqi Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309331866
parent 3d8eaa11
...@@ -318,7 +318,7 @@ class BasicTokenizer(object): ...@@ -318,7 +318,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object): class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation.""" """Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=400):
self.vocab = vocab self.vocab = vocab
self.unk_token = unk_token self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word self.max_input_chars_per_word = max_input_chars_per_word
......
...@@ -107,21 +107,36 @@ class DataProcessor(object): ...@@ -107,21 +107,36 @@ class DataProcessor(object):
class XnliProcessor(DataProcessor): class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set.""" """Processor for the XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self, process_text_fn=tokenization.convert_to_unicode): def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn) super(XnliProcessor, self).__init__(process_text_fn)
self.language = "zh" if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv( lines = []
for language in self.languages:
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli", os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language)) "multinli.train.%s.tsv" % language)))
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "train-%d" % (i) guid = "train-%d" % i
text_a = self.process_text_fn(line[0]) text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1]) text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2]) label = self.process_text_fn(line[2])
...@@ -138,10 +153,7 @@ class XnliProcessor(DataProcessor): ...@@ -138,10 +153,7 @@ class XnliProcessor(DataProcessor):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "dev-%d" % (i) guid = "dev-%d" % i
language = self.process_text_fn(line[0])
if language != self.process_text_fn(self.language):
continue
text_a = self.process_text_fn(line[6]) text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7]) text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1]) label = self.process_text_fn(line[1])
...@@ -149,6 +161,22 @@ class XnliProcessor(DataProcessor): ...@@ -149,6 +161,22 @@ class XnliProcessor(DataProcessor):
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
def get_test_examples(self, data_dir):
  """See base class."""
  rows = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
  # One bucket of examples per supported XNLI language.
  examples_by_lang = {
      lang: [] for lang in XnliProcessor.supported_languages
  }
  for idx, row in enumerate(rows):
    if idx == 0:  # Skip the TSV header row.
      continue
    language = self.process_text_fn(row[0])
    examples_by_lang[language].append(
        InputExample(
            guid="test-%d" % idx,
            text_a=self.process_text_fn(row[6]),
            text_b=self.process_text_fn(row[7]),
            label=self.process_text_fn(row[1])))
  return examples_by_lang
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["contradiction", "entailment", "neutral"] return ["contradiction", "entailment", "neutral"]
...@@ -678,6 +706,7 @@ def generate_tf_record_from_data_file(processor, ...@@ -678,6 +706,7 @@ def generate_tf_record_from_data_file(processor,
tokenizer, tokenizer,
train_data_output_path=None, train_data_output_path=None,
eval_data_output_path=None, eval_data_output_path=None,
test_data_output_path=None,
max_seq_length=128): max_seq_length=128):
"""Generates and saves training data into a tf record file. """Generates and saves training data into a tf record file.
...@@ -691,6 +720,8 @@ def generate_tf_record_from_data_file(processor, ...@@ -691,6 +720,8 @@ def generate_tf_record_from_data_file(processor,
will be saved. will be saved.
eval_data_output_path: Output to which processed tf record for evaluation eval_data_output_path: Output to which processed tf record for evaluation
will be saved. will be saved.
test_data_output_path: Output to which processed tf record for testing
will be saved. Must be a pattern template with {} if processor is XNLI.
max_seq_length: Maximum sequence length of the to be generated max_seq_length: Maximum sequence length of the to be generated
training/eval data. training/eval data.
...@@ -713,6 +744,19 @@ def generate_tf_record_from_data_file(processor, ...@@ -713,6 +744,19 @@ def generate_tf_record_from_data_file(processor,
label_list, max_seq_length, label_list, max_seq_length,
tokenizer, eval_data_output_path) tokenizer, eval_data_output_path)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features(
examples,
label_list, max_seq_length,
tokenizer, test_data_output_path.format(language))
else:
file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length,
tokenizer, test_data_output_path)
meta_data = { meta_data = {
"task_type": "bert_classification", "task_type": "bert_classification",
"processor_type": processor.get_processor_name(), "processor_type": processor.get_processor_name(),
...@@ -724,4 +768,12 @@ def generate_tf_record_from_data_file(processor, ...@@ -724,4 +768,12 @@ def generate_tf_record_from_data_file(processor,
if eval_data_output_path: if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples) meta_data["eval_data_size"] = len(eval_input_data_examples)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
meta_data["test_data_size"] = len(test_input_data_examples)
return meta_data return meta_data
...@@ -49,6 +49,12 @@ flags.DEFINE_enum("classification_task_name", "MNLI", ...@@ -49,6 +49,12 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI"], ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI"],
"The name of the task to train BERT classifier.") "The name of the task to train BERT classifier.")
# XNLI task specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    # Fixed typo in user-facing help text: "XNIL" -> "XNLI".
    "Language of training and evaluation data for XNLI task. If the value is "
    "'all', the data of all languages will be used for training.")
# BERT Squad task specific flags. # BERT Squad task specific flags.
flags.DEFINE_string( flags.DEFINE_string(
"squad_data_file", None, "squad_data_file", None,
...@@ -79,9 +85,14 @@ flags.DEFINE_string( ...@@ -79,9 +85,14 @@ flags.DEFINE_string(
flags.DEFINE_string( flags.DEFINE_string(
"eval_data_output_path", None, "eval_data_output_path", None,
"The path in which generated training input data will be written as tf" "The path in which generated evaluation input data will be written as tf"
" records.") " records.")
flags.DEFINE_string(
"test_data_output_path", None,
"The path in which generated test input data will be written as tf"
" records. If None, do not generate test data.")
flags.DEFINE_string("meta_data_file_path", None, flags.DEFINE_string("meta_data_file_path", None,
"The path in which input meta data will be written.") "The path in which input meta data will be written.")
...@@ -136,28 +147,37 @@ def generate_classifier_dataset(): ...@@ -136,28 +147,37 @@ def generate_classifier_dataset():
tokenizer, tokenizer,
train_data_output_path=FLAGS.train_data_output_path, train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path, eval_data_output_path=FLAGS.eval_data_output_path,
test_data_output_path=FLAGS.test_data_output_path,
max_seq_length=FLAGS.max_seq_length) max_seq_length=FLAGS.max_seq_length)
else: else:
processors = { processors = {
"cola": classifier_data_lib.ColaProcessor, "cola":
"mnli": classifier_data_lib.MnliProcessor, classifier_data_lib.ColaProcessor,
"mrpc": classifier_data_lib.MrpcProcessor, "mnli":
"qnli": classifier_data_lib.QnliProcessor, classifier_data_lib.MnliProcessor,
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor, "qqp": classifier_data_lib.QqpProcessor,
"sst-2": classifier_data_lib.SstProcessor, "sst-2":
"xnli": classifier_data_lib.XnliProcessor, classifier_data_lib.SstProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
} }
task_name = FLAGS.classification_task_name.lower() task_name = FLAGS.classification_task_name.lower()
if task_name not in processors: if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name)) raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name](processor_text_fn) processor = processors[task_name](process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file( return classifier_data_lib.generate_tf_record_from_data_file(
processor, processor,
FLAGS.input_data_dir, FLAGS.input_data_dir,
tokenizer, tokenizer,
train_data_output_path=FLAGS.train_data_output_path, train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path, eval_data_output_path=FLAGS.eval_data_output_path,
test_data_output_path=FLAGS.test_data_output_path,
max_seq_length=FLAGS.max_seq_length) max_seq_length=FLAGS.max_seq_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment