Commit 7647fbcf authored by Tianqi Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317211727
parent 79ae8004
@@ -191,12 +191,68 @@ class XnliProcessor(DataProcessor):
     return "XNLI"
 
 
-class PawsxProcessor(DataProcessor):
-  """Processor for the PAWS-X data set."""
+class XtremeXnliProcessor(DataProcessor):
+  """Processor for the XTREME XNLI data set."""
 
   supported_languages = [
-      "de", "en", "es", "fr", "ja", "ko", "zh"
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
   ]
 
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for (i, line) in enumerate(lines):
+        guid = f"test-{i}"
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-XNLI"
+
+
+class PawsxProcessor(DataProcessor):
+  """Processor for the PAWS-X data set."""
+
+  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
   def __init__(self,
                language="en",
                process_text_fn=tokenization.convert_to_unicode):
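Both new XTREME processors assume headerless, tab-separated files whose first three columns are text_a, text_b, and label, as the line[0]..line[2] indexing above shows. A minimal sketch of that parsing with an illustrative row; the read_rows helper is a hypothetical stand-in for DataProcessor._read_tsv, which is not shown in this diff:

import csv

def read_rows(path):
  # Hypothetical stand-in for DataProcessor._read_tsv: tab-delimited rows.
  with open(path, encoding="utf-8", newline="") as f:
    return list(csv.reader(f, delimiter="\t"))

# Illustrative XTREME-XNLI style row: premise, hypothesis, label.
rows = [["The cat slept.", "An animal was asleep.", "entailment"]]
for i, line in enumerate(rows):
  guid = "train-%d" % i
  text_a, text_b, label = line[0], line[1], line[2]
  print(guid, repr(text_a), repr(text_b), "->", label)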
@@ -219,8 +275,7 @@ class PawsxProcessor(DataProcessor):
       train_tsv = "translated_train.tsv"
     # Skips the header.
     lines.extend(
-        self._read_tsv(
-            os.path.join(data_dir, language, train_tsv))[1:])
+        self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
 
     examples = []
     for (i, line) in enumerate(lines):
@@ -235,34 +290,30 @@ class PawsxProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     lines = []
-    for language in PawsxProcessor.supported_languages:
-      # Skips the header.
-      lines.extend(
-          self._read_tsv(os.path.join(data_dir, language, "dev_2k.tsv"))[1:])
+    for lang in PawsxProcessor.supported_languages:
+      lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv")))
 
     examples = []
     for (i, line) in enumerate(lines):
       guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[1])
-      text_b = self.process_text_fn(line[2])
-      label = self.process_text_fn(line[3])
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in PawsxProcessor.supported_languages}
-    for language in PawsxProcessor.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, language, "test_2k.tsv"))
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
       for (i, line) in enumerate(lines):
-        if i == 0:
-          continue
         guid = "test-%d" % i
-        text_a = self.process_text_fn(line[1])
-        text_b = self.process_text_fn(line[2])
-        label = self.process_text_fn(line[3])
-        examples_by_lang[language].append(
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples_by_lang
@@ -273,7 +324,62 @@ class PawsxProcessor(DataProcessor):
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "PAWS-X"
+    return "XTREME-PAWS-X"
+
+
+class XtremePawsxProcessor(DataProcessor):
+  """Processor for the XTREME PAWS-X data set."""
+
+  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev_en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for (i, line) in enumerate(lines):
+        guid = "test-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = self.process_text_fn(line[2])
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-PAWS-X"
 
 
 class MnliProcessor(DataProcessor):
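A hedged usage sketch of the new processor: the data_dir path is hypothetical, the file names (train-en.tsv, test-{lang}.tsv) come from the code above, and the no-argument constructor assumes DataProcessor's default process_text_fn:

from official.nlp.data import classifier_data_lib  # module path assumed

processor = classifier_data_lib.XtremePawsxProcessor()
# Reads train-en.tsv under the given directory (path is illustrative).
train = processor.get_train_examples("/tmp/xtreme_pawsx")
# Returns a dict mapping each supported language to its test examples.
test_by_lang = processor.get_test_examples("/tmp/xtreme_pawsx")
print(processor.get_labels())        # ["0", "1"]
print(sorted(test_by_lang.keys()))   # ["de", "en", "es", "fr", "ja", "ko", "zh"]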
@@ -407,8 +513,8 @@ class QqpProcessor(DataProcessor):
         label = line[5]
       except IndexError:
         continue
-      examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b,
-                                   label=label))
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -583,15 +689,16 @@ class TfdsProcessor(DataProcessor):
     is_regression: Whether the task is a regression problem (defaults to False).
   """
 
-  def __init__(self, tfds_params,
+  def __init__(self,
+               tfds_params,
                process_text_fn=tokenization.convert_to_unicode):
     super(TfdsProcessor, self).__init__(process_text_fn)
     self._process_tfds_params_str(tfds_params)
     if self.module_import:
       importlib.import_module(self.module_import)
 
-    self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
-                                   with_info=True)
+    self.dataset, info = tfds.load(
+        self.dataset_name, data_dir=self.data_dir, with_info=True)
     if self.is_regression:
       self._labels = None
     else:
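For context on the reformatted call: tfds.load with with_info=True returns the dataset together with a tfds.core.DatasetInfo, which is presumably what the label setup below this hunk draws on. A standalone sketch; the dataset name is illustrative:

import tensorflow_datasets as tfds

# with_info=True makes tfds.load return (datasets, info); datasets is a
# dict of tf.data.Dataset objects keyed by split name.
dataset, info = tfds.load("glue/cola", with_info=True)
print(list(dataset.keys()))          # e.g. ['train', 'validation', 'test']
print(info.features["label"].names)  # class names for the label feature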
@@ -660,7 +767,11 @@ class TfdsProcessor(DataProcessor):
       if self.weight_key:
         weight = float(example[self.weight_key])
       examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
-                       weight=weight))
+          InputExample(
+              guid=guid,
+              text_a=text_a,
+              text_b=text_b,
+              label=label,
+              weight=weight))
     return examples
@@ -761,9 +872,12 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
   return feature
 
 
-def file_based_convert_examples_to_features(examples, label_list,
-                                            max_seq_length, tokenizer,
-                                            output_file, label_type=None):
+def file_based_convert_examples_to_features(examples,
+                                            label_list,
+                                            max_seq_length,
+                                            tokenizer,
+                                            output_file,
+                                            label_type=None):
   """Convert a set of `InputExample`s to a TFRecord file."""
 
   tf.io.gfile.makedirs(os.path.dirname(output_file))
@@ -779,6 +893,7 @@ def file_based_convert_examples_to_features(examples, label_list,
   def create_int_feature(values):
     f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
     return f
+
   def create_float_feature(values):
     f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
     return f
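The blank line added above is purely stylistic. For reference, this is how the two helpers feed into a tf.train.Example, with float features covering the regression label_type path; a standalone sketch with illustrative values:

import tensorflow as tf

def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def create_float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

features = {
    "input_ids": create_int_feature([101, 2023, 102]),  # illustrative token ids
    "label_ids": create_float_feature([0.7]),           # float label (regression)
}
example = tf.train.Example(features=tf.train.Features(feature=features))
print(example.SerializeToString()[:16])  # serialized record bytes for TFRecord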
@@ -857,8 +972,7 @@ def generate_tf_record_from_data_file(processor,
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
                                           max_seq_length, tokenizer,
-                                          train_data_output_path,
-                                          label_type)
+                                          train_data_output_path, label_type)
   num_training_data = len(train_input_data_examples)
 
   if eval_data_output_path:
@@ -873,10 +987,8 @@ def generate_tf_record_from_data_file(processor,
   if isinstance(test_input_data_examples, dict):
     for language, examples in test_input_data_examples.items():
       file_based_convert_examples_to_features(
-          examples,
-          label_list, max_seq_length,
-          tokenizer, test_data_output_path.format(language),
-          label_type)
+          examples, label_list, max_seq_length, tokenizer,
+          test_data_output_path.format(language), label_type)
   else:
     file_based_convert_examples_to_features(test_input_data_examples,
                                             label_list, max_seq_length,
...
@@ -48,8 +48,12 @@ flags.DEFINE_string(
 flags.DEFINE_enum("classification_task_name", "MNLI",
                   ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
-                   "PAWS-X"],
-                  "The name of the task to train BERT classifier.")
+                   "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+                  "The name of the task on which to train the BERT "
+                  "classifier. XTREME-XNLI differs from XNLI in: 1. the "
+                  "format of the input tsv files; 2. the dev set, which is "
+                  "English-only for XTREME but combines all languages for "
+                  "XNLI. The same applies to PAWS-X.")
 
 # XNLI task specific flag.
 flags.DEFINE_string(
@@ -176,7 +180,11 @@ def generate_classifier_dataset():
           language=FLAGS.xnli_language),
       "paws-x":
           functools.partial(classifier_data_lib.PawsxProcessor,
-                            language=FLAGS.pawsx_language)
+                            language=FLAGS.pawsx_language),
+      "xtreme-xnli":
+          functools.partial(classifier_data_lib.XtremeXnliProcessor),
+      "xtreme-paws-x":
+          functools.partial(classifier_data_lib.XtremePawsxProcessor)
   }
   task_name = FLAGS.classification_task_name.lower()
   if task_name not in processors:
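The registry resolves the lowercased flag value to a processor factory. A sketch of the lookup mechanics, with stand-in classes so it runs without the surrounding flags machinery; the real entries point at classifier_data_lib:

import functools

class XtremeXnliProcessor:   # stand-in for classifier_data_lib.XtremeXnliProcessor
  pass

class XtremePawsxProcessor:  # stand-in for classifier_data_lib.XtremePawsxProcessor
  pass

processors = {
    "xtreme-xnli": functools.partial(XtremeXnliProcessor),
    "xtreme-paws-x": functools.partial(XtremePawsxProcessor),
}
task_name = "XTREME-PAWS-X".lower()  # mirrors FLAGS.classification_task_name
if task_name not in processors:
  raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name]()  # calling the partial builds the processor
print(type(processor).__name__)     # XtremePawsxProcessor

The zero-argument partials are equivalent to the bare classes; they simply mirror the style of the language-parameterized XNLI and PAWS-X entries above them.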
...