Commit e0d33132 authored by Tianqi Liu's avatar Tianqi Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309331866
parent 3d8eaa11
......@@ -318,7 +318,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=400):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
......
......@@ -107,21 +107,36 @@ class DataProcessor(object):
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
self.language = "zh"
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
lines = []
for language in self.languages:
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language)))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
......@@ -138,10 +153,7 @@ class XnliProcessor(DataProcessor):
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = self.process_text_fn(line[0])
if language != self.process_text_fn(self.language):
continue
guid = "dev-%d" % i
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
......@@ -149,6 +161,22 @@ class XnliProcessor(DataProcessor):
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
  """Reads XNLI test examples and groups them by language code.

  Args:
    data_dir: Directory containing `xnli.test.tsv`.

  Returns:
    Dict mapping each supported language code to its list of `InputExample`s.
  """
  rows = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
  examples_by_lang = {
      lang: [] for lang in XnliProcessor.supported_languages
  }
  for idx, row in enumerate(rows):
    # Row 0 is the TSV header.
    if idx == 0:
      continue
    language = self.process_text_fn(row[0])
    example = InputExample(
        guid="test-%d" % idx,
        text_a=self.process_text_fn(row[6]),
        text_b=self.process_text_fn(row[7]),
        label=self.process_text_fn(row[1]))
    examples_by_lang[language].append(example)
  return examples_by_lang
def get_labels(self):
  """Returns the three NLI class labels used by the XNLI task."""
  labels = ["contradiction", "entailment", "neutral"]
  return labels
......@@ -678,6 +706,7 @@ def generate_tf_record_from_data_file(processor,
tokenizer,
train_data_output_path=None,
eval_data_output_path=None,
test_data_output_path=None,
max_seq_length=128):
"""Generates and saves training data into a tf record file.
......@@ -691,6 +720,8 @@ def generate_tf_record_from_data_file(processor,
will be saved.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
test_data_output_path: Output to which processed tf record for testing
will be saved. Must be a pattern template with {} if processor is XNLI.
max_seq_length: Maximum sequence length of the to be generated
training/eval data.
......@@ -713,6 +744,19 @@ def generate_tf_record_from_data_file(processor,
label_list, max_seq_length,
tokenizer, eval_data_output_path)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features(
examples,
label_list, max_seq_length,
tokenizer, test_data_output_path.format(language))
else:
file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length,
tokenizer, test_data_output_path)
meta_data = {
"task_type": "bert_classification",
"processor_type": processor.get_processor_name(),
......@@ -724,4 +768,12 @@ def generate_tf_record_from_data_file(processor,
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
meta_data["test_data_size"] = len(test_input_data_examples)
return meta_data
......@@ -49,6 +49,12 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI"],
"The name of the task to train BERT classifier.")
# XNLI task specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training and evaluation data for XNLI task. If the value is "
    "'all', the data of all languages will be used for training.")
# BERT Squad task specific flags.
flags.DEFINE_string(
"squad_data_file", None,
......@@ -79,9 +85,14 @@ flags.DEFINE_string(
flags.DEFINE_string(
"eval_data_output_path", None,
"The path in which generated training input data will be written as tf"
"The path in which generated evaluation input data will be written as tf"
" records.")
# Optional output path for test-split tf records; generation is skipped
# entirely when this flag is left unset.
flags.DEFINE_string(
    "test_data_output_path",
    None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data.")

flags.DEFINE_string(
    "meta_data_file_path",
    None,
    "The path in which input meta data will be written.")
......@@ -136,28 +147,37 @@ def generate_classifier_dataset():
tokenizer,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
test_data_output_path=FLAGS.test_data_output_path,
max_seq_length=FLAGS.max_seq_length)
else:
processors = {
"cola": classifier_data_lib.ColaProcessor,
"mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor,
"qnli": classifier_data_lib.QnliProcessor,
"cola":
classifier_data_lib.ColaProcessor,
"mnli":
classifier_data_lib.MnliProcessor,
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"sst-2": classifier_data_lib.SstProcessor,
"xnli": classifier_data_lib.XnliProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name](processor_text_fn)
processor = processors[task_name](process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
FLAGS.input_data_dir,
tokenizer,
train_data_output_path=FLAGS.train_data_output_path,
eval_data_output_path=FLAGS.eval_data_output_path,
test_data_output_path=FLAGS.test_data_output_path,
max_seq_length=FLAGS.max_seq_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment