Commit 2d08d875 authored by Tianqi Liu, committed by A. Unique TensorFlower

Add support for using translated data in the XTREME benchmarks.

PiperOrigin-RevId: 342448071
parent 8dace71e
@@ -938,45 +938,104 @@ class XtremePawsxProcessor(DataProcessor):
  """Processor for the XTREME PAWS-X data set."""
  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               translated_data_dir=None,
               only_use_en_dev=True):
    """See base class.

    Arguments:
      process_text_fn: See base class.
      translated_data_dir: If specified, will also include translated data in
        the training and testing data.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(XtremePawsxProcessor, self).__init__(process_text_fn)
    self.translated_data_dir = translated_data_dir
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.translated_data_dir is None:
      lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
      for i, line in enumerate(lines):
        guid = "train-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-train",
                         f"en-{lang}-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"train-{lang}-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = self.process_text_fn(line[4])
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.only_use_en_dev:
      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
      for i, line in enumerate(lines):
        guid = "dev-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
        for i, line in enumerate(lines):
          guid = f"dev-{lang}-{i}"
          text_a = self.process_text_fn(line[0])
          text_b = self.process_text_fn(line[1])
          label = self.process_text_fn(line[2])
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    examples_by_lang = {}
    for lang in self.supported_languages:
      examples_by_lang[lang] = []
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for i, line in enumerate(lines):
        guid = f"test-{lang}-{i}"
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = "0"
        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    if self.translated_data_dir is not None:
      for lang in self.supported_languages:
        if lang == "en":
          continue
        examples_by_lang[f"{lang}-en"] = []
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-test",
                         f"test-{lang}-en-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"test-{lang}-en-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = "0"
          examples_by_lang[f"{lang}-en"].append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
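
With the new constructor arguments, the PAWS-X processor can be pointed directly at a directory of translated TSVs. A minimal usage sketch (not part of this change; the import path follows the processors used elsewhere in this commit, and all paths are hypothetical):

from official.nlp.data import classifier_data_lib

# The translated directory is expected to follow the layout read above:
#   <translated_data_dir>/translate-train/en-<lang>-translated.tsv
#   <translated_data_dir>/translate-test/test-<lang>-en-translated.tsv
processor = classifier_data_lib.XtremePawsxProcessor(
    translated_data_dir="/data/xtreme/pawsx_translated",  # hypothetical path
    only_use_en_dev=False)  # also load dev-<lang>.tsv for every supported language
train_examples = processor.get_train_examples("/data/xtreme/pawsx")  # hypothetical path
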
@@ -996,45 +1055,111 @@ class XtremeXnliProcessor(DataProcessor):
      "ur", "vi", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               translated_data_dir=None,
               only_use_en_dev=True):
    """See base class.

    Arguments:
      process_text_fn: See base class.
      translated_data_dir: If specified, will also include translated data in
        the training data.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(XtremeXnliProcessor, self).__init__(process_text_fn)
    self.translated_data_dir = translated_data_dir
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
    examples = []
    if self.translated_data_dir is None:
      for i, line in enumerate(lines):
        guid = "train-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        if label == self.process_text_fn("contradictory"):
          label = self.process_text_fn("contradiction")
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-train",
                         f"en-{lang}-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"train-{lang}-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = self.process_text_fn(line[4])
          if label == self.process_text_fn("contradictory"):
            label = self.process_text_fn("contradiction")
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.only_use_en_dev:
      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
      for i, line in enumerate(lines):
        guid = "dev-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
        for i, line in enumerate(lines):
          guid = f"dev-{lang}-{i}"
          text_a = self.process_text_fn(line[0])
          text_b = self.process_text_fn(line[1])
          label = self.process_text_fn(line[2])
          if label == self.process_text_fn("contradictory"):
            label = self.process_text_fn("contradiction")
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    examples_by_lang = {}
    for lang in self.supported_languages:
      examples_by_lang[lang] = []
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for i, line in enumerate(lines):
        guid = f"test-{lang}-{i}"
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = "contradiction"
        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    if self.translated_data_dir is not None:
      for lang in self.supported_languages:
        if lang == "en":
          continue
        examples_by_lang[f"{lang}-en"] = []
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-test",
                         f"test-{lang}-en-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"test-{lang}-en-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = "contradiction"
          examples_by_lang[f"{lang}-en"].append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
...
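
The XNLI processor gains the same switches; in addition, when a translated data directory is supplied, get_test_examples returns one extra entry per non-English language, keyed "<lang>-en", holding the translate-test split. A small sketch (not part of this change; paths are hypothetical):

from official.nlp.data import classifier_data_lib

processor = classifier_data_lib.XtremeXnliProcessor(
    translated_data_dir="/data/xtreme/xnli_translated")  # hypothetical path
examples_by_lang = processor.get_test_examples("/data/xtreme/xnli")  # hypothetical path
# Original test sets are keyed by language code (e.g. examples_by_lang["de"]);
# translate-test sets are keyed "<lang>-en" (e.g. examples_by_lang["de-en"]).
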
@@ -46,20 +46,19 @@ flags.DEFINE_string(
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum(
    "classification_task_name", "MNLI", [
        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"
    ], "The name of the task to train BERT classifier. The "
    "difference between XTREME-XNLI and XNLI is: 1. the format "
    "of input tsv files; 2. the dev set for XTREME is english "
    "only and for XNLI is all languages combined. Same for "
    "PAWS-X.")

# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                  "The type of MNLI dataset.")

# XNLI task-specific flag.
flags.DEFINE_string(

@@ -73,6 +72,12 @@ flags.DEFINE_string(
    "Language of training data for PAWS-X task. If the value is 'all', the data "
    "of all languages will be used for training.")

# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
    "translated_input_data_dir", None,
    "The translated input data dir. Should contain the .tsv files (or other "
    "data files) for the task.")

# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

@@ -81,11 +86,19 @@ flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of BERT tagging (token classification) task.")

flags.DEFINE_bool("tagging_only_use_en_train", True,
                  "Whether only use english training data in tagging.")

# BERT Squad task-specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file in for generating training data for BERT squad task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for BERT squad "
    "task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "

@@ -105,6 +118,9 @@ flags.DEFINE_bool(
    "If true, then data will be preprocessed in a paragraph, query, class order"
    " instead of the BERT-style class, paragraph, query order.")

# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
@@ -148,16 +164,16 @@ flags.DEFINE_enum(
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.")

flags.DEFINE_string(
    "tfds_params", "", "Comma-separated list of TFDS parameter assigments for "
    "generic classfication data import (for more details "
    "see the TfdsProcessor class documentation).")


def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
          FLAGS.tfds_params)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(

@@ -171,8 +187,7 @@ def generate_classifier_dataset():
  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,

@@ -190,29 +205,40 @@ def generate_classifier_dataset():
      "imdb":
          classifier_data_lib.ImdbProcessor,
      "mnli":
          functools.partial(
              classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
      "mrpc":
          classifier_data_lib.MrpcProcessor,
      "qnli":
          classifier_data_lib.QnliProcessor,
      "qqp":
          classifier_data_lib.QqpProcessor,
      "rte":
          classifier_data_lib.RteProcessor,
      "sst-2":
          classifier_data_lib.SstProcessor,
      "sts-b":
          classifier_data_lib.StsBProcessor,
      "xnli":
          functools.partial(
              classifier_data_lib.XnliProcessor,
              language=FLAGS.xnli_language),
      "paws-x":
          functools.partial(
              classifier_data_lib.PawsxProcessor,
              language=FLAGS.pawsx_language),
      "wnli":
          classifier_data_lib.WnliProcessor,
      "xtreme-xnli":
          functools.partial(
              classifier_data_lib.XtremeXnliProcessor,
              translated_data_dir=FLAGS.translated_input_data_dir,
              only_use_en_dev=FLAGS.only_use_en_dev),
      "xtreme-paws-x":
          functools.partial(
              classifier_data_lib.XtremePawsxProcessor,
              translated_data_dir=FLAGS.translated_input_data_dir,
              only_use_en_dev=FLAGS.only_use_en_dev)
  }
  task_name = FLAGS.classification_task_name.lower()
  if task_name not in processors:
@@ -243,8 +269,7 @@ def generate_regression_dataset():
  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,

@@ -265,6 +290,7 @@ def generate_squad_dataset():
        input_file_path=FLAGS.squad_data_file,
        vocab_file_path=FLAGS.vocab_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,

@@ -277,6 +303,7 @@ def generate_squad_dataset():
        input_file_path=FLAGS.squad_data_file,
        sp_model_file=FLAGS.sp_model_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,

@@ -310,19 +337,23 @@ def generate_retrieval_dataset():
  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx":
          functools.partial(
              tagging_data_lib.PanxProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
      "udpos":
          functools.partial(
              tagging_data_lib.UdposProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
...
@@ -158,11 +158,20 @@ class FeatureWriter(object):
    self._writer.close()


def read_squad_examples(input_file, is_training,
                        version_2_with_negative,
                        translated_input_folder=None):
  """Read a SQuAD json file into a list of SquadExample."""
  with tf.io.gfile.GFile(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  if translated_input_folder is not None:
    translated_files = tf.io.gfile.glob(
        os.path.join(translated_input_folder, "*.json"))
    for file in translated_files:
      with tf.io.gfile.GFile(file, "r") as reader:
        input_data.extend(json.load(reader)["data"])

  def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
      return True

@@ -930,6 +939,7 @@ def _compute_softmax(scores):
def generate_tf_record_from_json_file(input_file_path,
                                      vocab_file_path,
                                      output_path,
                                      translated_input_folder=None,
                                      max_seq_length=384,
                                      do_lower_case=True,
                                      max_query_length=64,

@@ -940,7 +950,8 @@ def generate_tf_record_from_json_file(input_file_path,
  train_examples = read_squad_examples(
      input_file=input_file_path,
      is_training=True,
      version_2_with_negative=version_2_with_negative,
      translated_input_folder=translated_input_folder)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file_path, do_lower_case=do_lower_case)
  train_writer = FeatureWriter(filename=output_path, is_training=True)
...
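
The SQuAD pipeline picks up translated data through the new translated_input_folder argument: every *.json file in that folder is read as SQuAD-format JSON and its "data" entries are appended to the training examples. A hedged sketch of the WordPiece entry point (not part of this change; the module path and all file locations are assumptions, and the SentencePiece variant below accepts the same argument):

from official.nlp.data import squad_lib  # assumed module path

squad_lib.generate_tf_record_from_json_file(
    input_file_path="/data/squad/train-v1.1.json",                  # hypothetical
    vocab_file_path="/models/bert/vocab.txt",                       # hypothetical
    output_path="/data/tfrecords/squad_train.tf_record",            # hypothetical
    translated_input_folder="/data/xtreme/translate-train-squad")   # hypothetical
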
@@ -109,12 +109,22 @@ class InputFeatures(object):
    self.is_impossible = is_impossible


def read_squad_examples(input_file,
                        is_training,
                        version_2_with_negative,
                        translated_input_folder=None):
  """Read a SQuAD json file into a list of SquadExample."""
  del version_2_with_negative
  with tf.io.gfile.GFile(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  if translated_input_folder is not None:
    translated_files = tf.io.gfile.glob(
        os.path.join(translated_input_folder, "*.json"))
    for file in translated_files:
      with tf.io.gfile.GFile(file, "r") as reader:
        input_data.extend(json.load(reader)["data"])

  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:

@@ -922,6 +932,7 @@ class FeatureWriter(object):
def generate_tf_record_from_json_file(input_file_path,
                                      sp_model_file,
                                      output_path,
                                      translated_input_folder=None,
                                      max_seq_length=384,
                                      do_lower_case=True,
                                      max_query_length=64,

@@ -932,7 +943,8 @@ def generate_tf_record_from_json_file(input_file_path,
  train_examples = read_squad_examples(
      input_file=input_file_path,
      is_training=True,
      version_2_with_negative=version_2_with_negative,
      translated_input_folder=translated_input_folder)
  tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=sp_model_file)
  train_writer = FeatureWriter(
...
@@ -19,6 +19,7 @@ import os
from absl import logging
import tensorflow as tf

from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib

# A negative label id for the padding label, which will not contribute

@@ -89,13 +90,48 @@ class PanxProcessor(classifier_data_lib.DataProcessor):
      "tr", "et", "fi", "hu"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               only_use_en_train=True,
               only_use_en_dev=True):
    """See base class.

    Arguments:
      process_text_fn: See base class.
      only_use_en_train: If True, only use english training data. Otherwise,
        use training data from all languages.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(PanxProcessor, self).__init__(process_text_fn)
    self.only_use_en_train = only_use_en_train
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    examples = _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())
    if not self.only_use_en_train:
      for language in self.supported_languages:
        if language == "en":
          continue
        examples.extend(
            _read_one_file(
                os.path.join(data_dir, f"train-{language}.tsv"),
                self.get_labels()))
    return examples

  def get_dev_examples(self, data_dir):
    examples = _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
    if not self.only_use_en_dev:
      for language in self.supported_languages:
        if language == "en":
          continue
        examples.extend(
            _read_one_file(
                os.path.join(data_dir, f"dev-{language}.tsv"),
                self.get_labels()))
    return examples

  def get_test_examples(self, data_dir):
    examples_dict = {}

@@ -120,13 +156,49 @@ class UdposProcessor(classifier_data_lib.DataProcessor):
      "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               only_use_en_train=True,
               only_use_en_dev=True):
    """See base class.

    Arguments:
      process_text_fn: See base class.
      only_use_en_train: If True, only use english training data. Otherwise,
        use training data from all languages.
      only_use_en_dev: If True, only use english dev data. Otherwise, use dev
        data from all languages.
    """
    super(UdposProcessor, self).__init__(process_text_fn)
    self.only_use_en_train = only_use_en_train
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    if self.only_use_en_train:
      examples = _read_one_file(
          os.path.join(data_dir, "train-en.tsv"), self.get_labels())
    else:
      examples = []
      # Uses glob because some languages are missing in train.
      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
        examples.extend(
            _read_one_file(filepath, self.get_labels()))
    return examples

  def get_dev_examples(self, data_dir):
    if self.only_use_en_dev:
      examples = _read_one_file(
          os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
    else:
      examples = []
      for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
        examples.extend(
            _read_one_file(filepath, self.get_labels()))
    return examples

  def get_test_examples(self, data_dir):
    examples_dict = {}
...
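
Both tagging processors now accept only_use_en_train and only_use_en_dev. A minimal sketch (not part of this change; the import path mirrors the classifier_data_lib import above, and data paths are hypothetical):

from official.nlp.data import tagging_data_lib  # assumed to sit beside classifier_data_lib

# Directory expected to contain train-<lang>.tsv and dev-<lang>.tsv files.
processor = tagging_data_lib.PanxProcessor(
    only_use_en_train=False,  # extend training data with every supported language
    only_use_en_dev=False)    # likewise for the dev split
train_examples = processor.get_train_examples("/data/xtreme/panx")  # hypothetical path
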
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
from typing import List, Union, Optional

from absl import logging
import dataclasses

@@ -159,8 +159,7 @@ class SentencePredictionTask(base_task.Task):
    logs = {self.loss: loss}
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          'sentence_prediction':  # Ensure one prediction along batch dimension.
              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
          'labels':
              labels,

@@ -228,32 +227,34 @@ class SentencePredictionTask(base_task.Task):
        ckpt_dir_or_file)


def predict(task: SentencePredictionTask,
            params: cfg.DataConfig,
            model: tf.keras.Model,
            params_aug: Optional[cfg.DataConfig] = None,
            test_time_aug_wgt: float = 0.3) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.
    params_aug: A `cfg.DataConfig` object for augmented data.
    test_time_aug_wgt: Test time augmentation weight. The prediction score will
      use (1. - test_time_aug_wgt) original prediction plus test_time_aug_wgt
      augmented prediction.

  Returns:
    A list of predictions with length of `num_examples`. For regression task,
    each element in the list is the predicted score; for classification task,
    each element is the predicted class id.
  """

  def predict_step(inputs):
    """Replicated prediction calculation."""
    x, _ = inputs
    example_id = x.pop('example_id')
    outputs = task.inference_step(x, model)
    return dict(example_id=example_id, predictions=outputs)

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs."""

@@ -272,4 +273,22 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
  # When running on TPU POD, the order of output cannot be maintained,
  # so we need to sort by example_id.
  outputs = sorted(outputs, key=lambda x: x[0])
  is_regression = task.task_config.model.num_classes == 1
  if params_aug is not None:
    dataset_aug = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), task.build_inputs, params_aug)
    outputs_aug = utils.predict(predict_step, aggregate_fn, dataset_aug)
    outputs_aug = sorted(outputs_aug, key=lambda x: x[0])
    if is_regression:
      return [(1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1]
              for x, y in zip(outputs, outputs_aug)]
    else:
      return [
          tf.argmax(
              (1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1],
              axis=-1) for x, y in zip(outputs, outputs_aug)
      ]
  if is_regression:
    return [x[1] for x in outputs]
  else:
    return [tf.argmax(x[1], axis=-1) for x in outputs]
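
The extended predict supports test-time augmentation: when params_aug is given, a second dataset is built from it and each score becomes (1. - test_time_aug_wgt) * original + test_time_aug_wgt * augmented before the optional argmax. A usage sketch (not part of this change; building task, model, and the two data configs is assumed to have happened elsewhere):

predictions = predict(
    task=task,                          # an existing SentencePredictionTask
    params=eval_data_config,            # original eval split, e.g. English dev
    model=model,
    params_aug=translated_data_config,  # translated copy of the same split
    test_time_aug_wgt=0.3)              # 0.7 * original + 0.3 * augmented
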