Commit e9057c4d authored by Tianqi Liu, committed by A. Unique TensorFlower

Add support for using translated data in XTREME benchmarks.

PiperOrigin-RevId: 342448071
parent f409e4d0
@@ -938,45 +938,104 @@ class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
translated_data_dir=None,
only_use_en_dev=True):
"""See base class.
Arguments:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training and testing data.
only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(XtremePawsxProcessor, self).__init__(process_text_fn)
self.translated_data_dir = translated_data_dir
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is None:
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-train",
f"en-{lang}-translated.tsv"))
for i, line in enumerate(lines):
guid = f"train-{lang}-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = self.process_text_fn(line[4])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.only_use_en_dev:
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"dev-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
examples_by_lang = {}
for lang in self.supported_languages:
examples_by_lang[lang] = []
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for i, line in enumerate(lines):
guid = "test-%d" % i
guid = f"test-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is not None:
for lang in self.supported_languages:
if lang == "en":
continue
examples_by_lang[f"{lang}-en"] = []
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-test",
f"test-{lang}-en-translated.tsv"))
for i, line in enumerate(lines):
guid = f"test-{lang}-en-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = "0"
examples_by_lang[f"{lang}-en"].append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
@@ -996,45 +1055,111 @@ class XtremeXnliProcessor(DataProcessor):
"ur", "vi", "zh"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
translated_data_dir=None,
only_use_en_dev=True):
"""See base class.
Arguments:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training data.
only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(XtremeXnliProcessor, self).__init__(process_text_fn)
self.translated_data_dir = translated_data_dir
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is None:
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-train",
f"en-{lang}-translated.tsv"))
for i, line in enumerate(lines):
guid = f"train-{lang}-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = self.process_text_fn(line[4])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.only_use_en_dev:
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"dev-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
examples_by_lang = {}
for lang in self.supported_languages:
examples_by_lang[lang] = []
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"test-{i}"
guid = f"test-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is not None:
for lang in self.supported_languages:
if lang == "en":
continue
examples_by_lang[f"{lang}-en"] = []
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-test",
f"test-{lang}-en-translated.tsv"))
for i, line in enumerate(lines):
guid = f"test-{lang}-en-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = "contradiction"
examples_by_lang[f"{lang}-en"].append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
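For reference, a minimal sketch of how the extended processor might be driven directly (the module path official.nlp.data.classifier_data_lib and all paths below are assumptions; the translate-train/translate-test file names match the ones read above):

from official.nlp.data import classifier_data_lib

# translated_data_dir is expected to contain translate-train/en-{lang}-translated.tsv
# and translate-test/test-{lang}-en-translated.tsv; all directories are placeholders.
processor = classifier_data_lib.XtremePawsxProcessor(
    translated_data_dir="/data/xtreme/translations/PAWSX",
    only_use_en_dev=False)
train_examples = processor.get_train_examples("/data/xtreme/pawsx")
test_examples_by_lang = processor.get_test_examples("/data/xtreme/pawsx")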
@@ -46,20 +46,19 @@ flags.DEFINE_string(
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI",
"QQP", "RTE", "SST-2", "STS-B", "WNLI", "XNLI",
"XTREME-XNLI", "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
flags.DEFINE_enum(
"classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"
], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# MNLI task-specific flag.
flags.DEFINE_enum(
"mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task-specific flag.
flags.DEFINE_string(
@@ -73,6 +72,12 @@ flags.DEFINE_string(
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
"translated_input_data_dir", None,
"The translated input data dir. Should contain the .tsv files (or other "
"data files) for the task.")
# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
@@ -81,11 +86,19 @@ flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
flags.DEFINE_bool("tagging_only_use_en_train", True,
"Whether only use english training data in tagging.")
# BERT Squad task-specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
flags.DEFINE_string(
"translated_squad_data_folder", None,
"The translated data folder for generating training data for BERT squad "
"task.")
flags.DEFINE_integer(
"doc_stride", 128,
"When splitting up a long document into chunks, how much stride to "
@@ -105,6 +118,9 @@ flags.DEFINE_bool(
"If true, then data will be preprocessed in a paragraph, query, class order"
" instead of the BERT-style class, paragraph, query order.")
# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
@@ -148,16 +164,16 @@ flags.DEFINE_enum(
"or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
"while ALBERT uses SentencePiece tokenizer.")
flags.DEFINE_string("tfds_params", "",
"Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation).")
flags.DEFINE_string(
"tfds_params", "", "Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation).")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.classification_task_name
or FLAGS.tfds_params)
assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
FLAGS.tfds_params)
if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
@@ -171,8 +187,7 @@ def generate_classifier_dataset():
if FLAGS.tfds_params:
processor = classifier_data_lib.TfdsProcessor(
tfds_params=FLAGS.tfds_params,
process_text_fn=processor_text_fn)
tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
@@ -190,29 +205,40 @@ def generate_classifier_dataset():
"imdb":
classifier_data_lib.ImdbProcessor,
"mnli":
functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type),
functools.partial(
classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"qqp":
classifier_data_lib.QqpProcessor,
"rte":
classifier_data_lib.RteProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
functools.partial(
classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
"paws-x":
functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
functools.partial(
classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language),
"wnli":
classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
functools.partial(
classifier_data_lib.XtremeXnliProcessor,
translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
functools.partial(
classifier_data_lib.XtremePawsxProcessor,
translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev)
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
@@ -243,8 +269,7 @@ def generate_regression_dataset():
if FLAGS.tfds_params:
processor = classifier_data_lib.TfdsProcessor(
tfds_params=FLAGS.tfds_params,
process_text_fn=processor_text_fn)
tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
@@ -265,6 +290,7 @@ def generate_squad_dataset():
input_file_path=FLAGS.squad_data_file,
vocab_file_path=FLAGS.vocab_file,
output_path=FLAGS.train_data_output_path,
translated_input_folder=FLAGS.translated_squad_data_folder,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
@@ -277,6 +303,7 @@ def generate_squad_dataset():
input_file_path=FLAGS.squad_data_file,
sp_model_file=FLAGS.sp_model_file,
output_path=FLAGS.train_data_output_path,
translated_input_folder=FLAGS.translated_squad_data_folder,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
@@ -310,19 +337,23 @@ def generate_retrieval_dataset():
processor = processors[task_name](process_text_fn=processor_text_fn)
return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
processor,
FLAGS.input_data_dir,
tokenizer,
FLAGS.eval_data_output_path,
FLAGS.test_data_output_path,
FLAGS.max_seq_length)
processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
"panx":
functools.partial(
tagging_data_lib.PanxProcessor,
only_use_en_train=FLAGS.tagging_only_use_en_train,
only_use_en_dev=FLAGS.only_use_en_dev),
"udpos":
functools.partial(
tagging_data_lib.UdposProcessor,
only_use_en_train=FLAGS.tagging_only_use_en_train,
only_use_en_dev=FLAGS.only_use_en_dev),
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
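Taken together, the new flags thread translated data through the whole data-generation script: for example, --classification_task_name=XTREME-XNLI can be combined with --translated_input_data_dir and --only_use_en_dev=false, the SQuAD path picks up --translated_squad_data_folder, and tagging uses --tagging_only_use_en_train together with --only_use_en_dev (flag names as defined above; the directory values are user-supplied paths).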
@@ -158,11 +158,20 @@ class FeatureWriter(object):
self._writer.close()
def read_squad_examples(input_file, is_training, version_2_with_negative):
def read_squad_examples(input_file, is_training,
version_2_with_negative,
translated_input_folder=None):
"""Read a SQuAD json file into a list of SquadExample."""
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
if translated_input_folder is not None:
translated_files = tf.io.gfile.glob(
os.path.join(translated_input_folder, "*.json"))
for file in translated_files:
with tf.io.gfile.GFile(file, "r") as reader:
input_data.extend(json.load(reader)["data"])
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
@@ -930,6 +939,7 @@ def _compute_softmax(scores):
def generate_tf_record_from_json_file(input_file_path,
vocab_file_path,
output_path,
translated_input_folder=None,
max_seq_length=384,
do_lower_case=True,
max_query_length=64,
@@ -940,7 +950,8 @@ def generate_tf_record_from_json_file(input_file_path,
train_examples = read_squad_examples(
input_file=input_file_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
version_2_with_negative=version_2_with_negative,
translated_input_folder=translated_input_folder)
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file_path, do_lower_case=do_lower_case)
train_writer = FeatureWriter(filename=output_path, is_training=True)
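A minimal usage sketch of the extended reader (the module path official.nlp.data.squad_lib and the file paths are assumptions); every *.json file under translated_input_folder is appended to the original data by the glob loop above:

from official.nlp.data import squad_lib

# Placeholder paths; translated_input_folder holds translated SQuAD-format JSON files.
train_examples = squad_lib.read_squad_examples(
    input_file="/data/squad/train-v1.1.json",
    is_training=True,
    version_2_with_negative=False,
    translated_input_folder="/data/xtreme/translate-train/squad")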
@@ -109,12 +109,22 @@ class InputFeatures(object):
self.is_impossible = is_impossible
def read_squad_examples(input_file, is_training, version_2_with_negative):
def read_squad_examples(input_file,
is_training,
version_2_with_negative,
translated_input_folder=None):
"""Read a SQuAD json file into a list of SquadExample."""
del version_2_with_negative
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
if translated_input_folder is not None:
translated_files = tf.io.gfile.glob(
os.path.join(translated_input_folder, "*.json"))
for file in translated_files:
with tf.io.gfile.GFile(file, "r") as reader:
input_data.extend(json.load(reader)["data"])
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
@@ -922,6 +932,7 @@ class FeatureWriter(object):
def generate_tf_record_from_json_file(input_file_path,
sp_model_file,
output_path,
translated_input_folder=None,
max_seq_length=384,
do_lower_case=True,
max_query_length=64,
@@ -932,7 +943,8 @@ def generate_tf_record_from_json_file(input_file_path,
train_examples = read_squad_examples(
input_file=input_file_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
version_2_with_negative=version_2_with_negative,
translated_input_folder=translated_input_folder)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=sp_model_file)
train_writer = FeatureWriter(
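The SentencePiece pipeline accepts the same optional folder; a sketch under the same assumptions (module path official.nlp.data.squad_lib_sp, placeholder paths):

from official.nlp.data import squad_lib_sp

# Writes training tf.Examples that mix original and translated SQuAD data.
input_meta_data = squad_lib_sp.generate_tf_record_from_json_file(
    input_file_path="/data/squad/train-v1.1.json",
    sp_model_file="/models/albert/sp.model",
    output_path="/tmp/squad_train.tf_record",
    translated_input_folder="/data/xtreme/translate-train/squad")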
......@@ -19,6 +19,7 @@ import os
from absl import logging
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
# A negative label id for the padding label, which will not contribute
@@ -89,13 +90,48 @@ class PanxProcessor(classifier_data_lib.DataProcessor):
"tr", "et", "fi", "hu"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
only_use_en_train=True,
only_use_en_dev=True):
"""See base class.
Arguments:
process_text_fn: See base class.
only_use_en_train: If True, only use English training data. Otherwise, use
training data from all languages.
only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(PanxProcessor, self).__init__(process_text_fn)
self.only_use_en_train = only_use_en_train
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
return _read_one_file(
examples = _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
if not self.only_use_en_train:
for language in self.supported_languages:
if language == "en":
continue
examples.extend(
_read_one_file(
os.path.join(data_dir, f"train-{language}.tsv"),
self.get_labels()))
return examples
def get_dev_examples(self, data_dir):
return _read_one_file(
examples = _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
if not self.only_use_en_dev:
for language in self.supported_languages:
if language == "en":
continue
examples.extend(
_read_one_file(
os.path.join(data_dir, f"dev-{language}.tsv"),
self.get_labels()))
return examples
def get_test_examples(self, data_dir):
examples_dict = {}
@@ -120,13 +156,49 @@ class UdposProcessor(classifier_data_lib.DataProcessor):
"ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
only_use_en_train=True,
only_use_en_dev=True):
"""See base class.
Arguments:
process_text_fn: See base class.
only_use_en_train: If True, only use English training data. Otherwise, use
training data from all languages.
only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(UdposProcessor, self).__init__(process_text_fn)
self.only_use_en_train = only_use_en_train
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
if self.only_use_en_train:
examples = _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
else:
examples = []
# Uses glob because some languages are missing in train.
for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
examples.extend(
_read_one_file(
filepath,
self.get_labels()))
return examples
def get_dev_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
if self.only_use_en_dev:
examples = _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
else:
examples = []
for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
examples.extend(
_read_one_file(
filepath,
self.get_labels()))
return examples
def get_test_examples(self, data_dir):
examples_dict = {}
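A short sketch of the new tagging-processor options (the module path official.nlp.data.tagging_data_lib and the data_dir value are assumptions; the directory is laid out as train-{lang}.tsv / dev-{lang}.tsv, as read above):

from official.nlp.data import tagging_data_lib

processor = tagging_data_lib.PanxProcessor(
    only_use_en_train=False,  # also read train-{lang}.tsv for the other languages
    only_use_en_dev=False)    # also read dev-{lang}.tsv for the other languages
train_examples = processor.get_train_examples("/data/xtreme/panx")
dev_examples = processor.get_dev_examples("/data/xtreme/panx")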
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
from typing import List, Union
from typing import List, Union, Optional
from absl import logging
import dataclasses
@@ -159,8 +159,7 @@ class SentencePredictionTask(base_task.Task):
logs = {self.loss: loss}
if self.metric_type == 'matthews_corrcoef':
logs.update({
'sentence_prediction':
# Ensure one prediction along batch dimension.
'sentence_prediction': # Ensure one prediction along batch dimension.
tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
'labels':
labels,
@@ -228,32 +227,34 @@ class SentencePredictionTask(base_task.Task):
ckpt_dir_or_file)
def predict(task: SentencePredictionTask, params: cfg.DataConfig,
model: tf.keras.Model) -> List[Union[int, float]]:
def predict(task: SentencePredictionTask,
params: cfg.DataConfig,
model: tf.keras.Model,
params_aug: Optional[cfg.DataConfig] = None,
test_time_aug_wgt: float = 0.3) -> List[Union[int, float]]:
"""Predicts on the input data.
Args:
task: A `SentencePredictionTask` object.
params: A `cfg.DataConfig` object.
model: A keras.Model.
params_aug: A `cfg.DataConfig` object for augmented data.
test_time_aug_wgt: Test-time augmentation weight. The prediction score will
be (1. - test_time_aug_wgt) * original prediction + test_time_aug_wgt *
augmented prediction.
Returns:
A list of predictions with length of `num_examples`. For regression task,
each element in the list is the predicted score; for classification task,
each element is the predicted class id.
"""
is_regression = task.task_config.model.num_classes == 1
def predict_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
example_id = x.pop('example_id')
outputs = task.inference_step(x, model)
if is_regression:
return dict(example_id=example_id, predictions=outputs)
else:
return dict(
example_id=example_id, predictions=tf.argmax(outputs, axis=-1))
return dict(example_id=example_id, predictions=outputs)
def aggregate_fn(state, outputs):
"""Concatenates model's outputs."""
@@ -272,4 +273,22 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
# When running on TPU POD, the order of output cannot be maintained,
# so we need to sort by example_id.
outputs = sorted(outputs, key=lambda x: x[0])
return [x[1] for x in outputs]
is_regression = task.task_config.model.num_classes == 1
if params_aug is not None:
dataset_aug = orbit.utils.make_distributed_dataset(
tf.distribute.get_strategy(), task.build_inputs, params_aug)
outputs_aug = utils.predict(predict_step, aggregate_fn, dataset_aug)
outputs_aug = sorted(outputs_aug, key=lambda x: x[0])
if is_regression:
return [(1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1]
for x, y in zip(outputs, outputs_aug)]
else:
return [
tf.argmax(
(1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1],
axis=-1) for x, y in zip(outputs, outputs_aug)
]
if is_regression:
return [x[1] for x in outputs]
else:
return [tf.argmax(x[1], axis=-1) for x in outputs]
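A sketch of test-time augmentation with the extended predict(): task, params, and model are assumed to already exist, and params_translated is a hypothetical cfg.DataConfig pointing at a translated copy of the eval split.

# Plain predictions (unchanged behaviour).
predictions = predict(task, params, model)

# Mixed predictions: 0.7 * original output + 0.3 * augmented output per example.
predictions_tta = predict(
    task, params, model,
    params_aug=params_translated,
    test_time_aug_wgt=0.3)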