Commit 7647fbcf authored by Tianqi Liu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 317211727
parent 79ae8004
......@@ -191,12 +191,68 @@ class XnliProcessor(DataProcessor):
return "XNLI"
class PawsxProcessor(DataProcessor):
"""Processor for the PAWS-X data set."""
class XtremeXnliProcessor(DataProcessor):
  """Processor for the XTREME XNLI data set.

  Reads headerless TSV files from `data_dir` whose first three columns are
  used as `text_a`, `text_b` and `label` respectively:
    * `train-en.tsv` for training,
    * `dev-en.tsv` for evaluation,
    * `test-<lang>.tsv` for every code in `supported_languages` for testing.
  """
  # Language codes that have a `test-<lang>.tsv` file. Every code must appear
  # exactly once and every entry must be followed by a comma: adjacent string
  # literals without a comma are silently concatenated by Python.
  supported_languages = [
      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
      "ur", "vi", "zh"
  ]

  def _rows_to_examples(self, lines, set_type):
    """Converts TSV rows into a list of `InputExample`s.

    Args:
      lines: Iterable of rows; columns 0, 1 and 2 of each row are read.
      set_type: Prefix for the example guid ("train", "dev" or "test").

    Returns:
      A list with one `InputExample` per row, with guid `<set_type>-<index>`.
    """
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%d" % (set_type, i)
      text_a = self.process_text_fn(line[0])
      text_b = self.process_text_fn(line[1])
      label = self.process_text_fn(line[2])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
    return self._rows_to_examples(lines, "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
    return self._rows_to_examples(lines, "dev")

  def get_test_examples(self, data_dir):
    """See base class.

    Returns:
      A dict mapping each supported language code to the list of examples
      read from that language's `test-<lang>.tsv` file.
    """
    examples_by_lang = {}
    for lang in self.supported_languages:
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      examples_by_lang[lang] = self._rows_to_examples(lines, "test")
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-XNLI"
class PawsxProcessor(DataProcessor):
"""Processor for the PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
......@@ -219,8 +275,7 @@ class PawsxProcessor(DataProcessor):
train_tsv = "translated_train.tsv"
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, language, train_tsv))[1:])
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = []
for (i, line) in enumerate(lines):
......@@ -235,34 +290,30 @@ class PawsxProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
lines = []
for language in PawsxProcessor.supported_languages:
# Skips the header.
lines.extend(
self._read_tsv(os.path.join(data_dir, language, "dev_2k.tsv"))[1:])
for lang in PawsxProcessor.supported_languages:
lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv")))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
label = self.process_text_fn(line[3])
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in PawsxProcessor.supported_languages}
for language in PawsxProcessor.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, language, "test_2k.tsv"))
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
label = self.process_text_fn(line[3])
examples_by_lang[language].append(
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
......@@ -273,7 +324,62 @@ class PawsxProcessor(DataProcessor):
@staticmethod
def get_processor_name():
"""See base class."""
return "PAWS-X"
return "XTREME-PAWS-X"
class XtremePawsxProcessor(DataProcessor):
  """Processor for the XTREME PAWS-X data set.

  Reads headerless TSV files from `data_dir` whose first three columns are
  used as `text_a`, `text_b` and `label` respectively:
    * `train-en.tsv` for training,
    * `dev-en.tsv` for evaluation,
    * `test-<lang>.tsv` for every code in `supported_languages` for testing.
  """
  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]

  def _rows_to_examples(self, lines, set_type):
    """Converts TSV rows into a list of `InputExample`s.

    Args:
      lines: Iterable of rows; columns 0, 1 and 2 of each row are read.
      set_type: Prefix for the example guid ("train", "dev" or "test").

    Returns:
      A list with one `InputExample` per row, with guid `<set_type>-<index>`.
    """
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%d" % (set_type, i)
      text_a = self.process_text_fn(line[0])
      text_b = self.process_text_fn(line[1])
      label = self.process_text_fn(line[2])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
    return self._rows_to_examples(lines, "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    # Hyphenated like the train/test file names ("<split>-<lang>.tsv"); the
    # underscore spelling "dev_en.tsv" did not match that naming scheme.
    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
    return self._rows_to_examples(lines, "dev")

  def get_test_examples(self, data_dir):
    """See base class.

    Returns:
      A dict mapping each supported language code to the list of examples
      read from that language's `test-<lang>.tsv` file.
    """
    examples_by_lang = {}
    for lang in self.supported_languages:
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      examples_by_lang[lang] = self._rows_to_examples(lines, "test")
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-PAWS-X"
class MnliProcessor(DataProcessor):
......@@ -407,8 +513,8 @@ class QqpProcessor(DataProcessor):
label = line[5]
except IndexError:
continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b,
label=label))
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -583,15 +689,16 @@ class TfdsProcessor(DataProcessor):
is_regression: Whether the task is a regression problem (defaults to False).
"""
def __init__(self, tfds_params,
def __init__(self,
tfds_params,
process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params)
if self.module_import:
importlib.import_module(self.module_import)
self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
with_info=True)
self.dataset, info = tfds.load(
self.dataset_name, data_dir=self.data_dir, with_info=True)
if self.is_regression:
self._labels = None
else:
......@@ -660,8 +767,12 @@ class TfdsProcessor(DataProcessor):
if self.weight_key:
weight = float(example[self.weight_key])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
weight=weight))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
label=label,
weight=weight))
return examples
......@@ -761,9 +872,12 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
return feature
def file_based_convert_examples_to_features(examples, label_list,
max_seq_length, tokenizer,
output_file, label_type=None):
def file_based_convert_examples_to_features(examples,
label_list,
max_seq_length,
tokenizer,
output_file,
label_type=None):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
......@@ -779,6 +893,7 @@ def file_based_convert_examples_to_features(examples, label_list,
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
......@@ -857,8 +972,7 @@ def generate_tf_record_from_data_file(processor,
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path,
label_type)
train_data_output_path, label_type)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
......@@ -873,10 +987,8 @@ def generate_tf_record_from_data_file(processor,
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features(
examples,
label_list, max_seq_length,
tokenizer, test_data_output_path.format(language),
label_type)
examples, label_list, max_seq_length, tokenizer,
test_data_output_path.format(language), label_type)
else:
file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length,
......
......@@ -48,8 +48,12 @@ flags.DEFINE_string(
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
"PAWS-X"],
"The name of the task to train BERT classifier.")
"PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# XNLI task specific flag.
flags.DEFINE_string(
......@@ -176,7 +180,11 @@ def generate_classifier_dataset():
language=FLAGS.xnli_language),
"paws-x":
functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language)
language=FLAGS.pawsx_language),
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment