Commit 571369aa authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add STS-B glue data preprocessor and regression support, and alpha-order the preprocessor classes.

PiperOrigin-RevId: 320318976
parent 7bb5ab6d
...@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization ...@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization
class InputExample(object): class InputExample(object):
"""A single training/test example for simple sequence classification.""" """A single training/test example for simple seq regression/classification."""
def __init__(self, def __init__(self,
guid, guid,
...@@ -48,8 +48,9 @@ class InputExample(object): ...@@ -48,8 +48,9 @@ class InputExample(object):
sequence tasks, only this sequence must be specified. sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence. text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks. Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be label: (Optional) string for classification, float for regression. The
specified for train and dev examples, but not for test examples. label of the example. This should be specified for train and dev
examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during weight: (Optional) float. The weight of the example to be used during
training. training.
int_iden: (Optional) int. The int identification number of example in the int_iden: (Optional) int. The int identification number of example in the
...@@ -84,10 +85,12 @@ class InputFeatures(object): ...@@ -84,10 +85,12 @@ class InputFeatures(object):
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for converters for seq regression/classification datasets."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode): def __init__(self, process_text_fn=tokenization.convert_to_unicode):
self.process_text_fn = process_text_fn self.process_text_fn = process_text_fn
self.is_regression = False
self.label_type = None
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set.""" """Gets a collection of `InputExample`s for the train set."""
...@@ -121,76 +124,70 @@ class DataProcessor(object): ...@@ -121,76 +124,70 @@ class DataProcessor(object):
return lines return lines
class XnliProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the XNLI data set.""" """Processor for the CoLA data set (GLUE version)."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = [] return self._create_examples(
for language in self.languages: self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language))[1:])
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: # Only the test set has a header
if set_type == "test" and i == 0:
continue continue
guid = "dev-%d" % i guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[6]) if set_type == "test":
text_b = self.process_text_fn(line[7]) text_a = self.process_text_fn(line[1])
label = self.process_text_fn(line[1]) label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv")) return self._create_examples(
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages} self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
language = self.process_text_fn(line[0])
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples_by_lang[language].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -199,65 +196,69 @@ class XnliProcessor(DataProcessor): ...@@ -199,65 +196,69 @@ class XnliProcessor(DataProcessor):
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "XNLI" return "MNLI"
class XtremeXnliProcessor(DataProcessor):
"""Processor for the XTREME XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
guid = "train-%d" % i if i == 0:
text_a = self.process_text_fn(line[0]) continue
text_b = self.process_text_fn(line[1]) guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
label = self.process_text_fn(line[2]) text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv")) return self._create_examples(
examples = [] self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
examples_by_lang = {k: [] for k in self.supported_languages} return self._create_examples(
for lang in self.supported_languages: self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = f"test-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["contradiction", "entailment", "neutral"] return ["0", "1"]
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "XTREME-XNLI" return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class PawsxProcessor(DataProcessor): class PawsxProcessor(DataProcessor):
...@@ -339,154 +340,8 @@ class PawsxProcessor(DataProcessor): ...@@ -339,154 +340,8 @@ class PawsxProcessor(DataProcessor):
return "XTREME-PAWS-X" return "XTREME-PAWS-X"
class XtremePawsxProcessor(DataProcessor): class QnliProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set.""" """Processor for the QNLI data set (GLUE version)."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-PAWS-X"
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[8])
text_b = self.process_text_fn(line[9])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[3])
text_b = self.process_text_fn(line[4])
if set_type == "test":
label = "0"
else:
label = self.process_text_fn(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version)."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -496,7 +351,7 @@ class QqpProcessor(DataProcessor): ...@@ -496,7 +351,7 @@ class QqpProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -505,12 +360,12 @@ class QqpProcessor(DataProcessor): ...@@ -505,12 +360,12 @@ class QqpProcessor(DataProcessor):
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1"] return ["entailment", "not_entailment"]
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "QQP" return "QNLI"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
...@@ -518,20 +373,22 @@ class QqpProcessor(DataProcessor): ...@@ -518,20 +373,22 @@ class QqpProcessor(DataProcessor):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, 1)
try: if set_type == "test":
text_a = line[3] text_a = tokenization.convert_to_unicode(line[1])
text_b = line[4] text_b = tokenization.convert_to_unicode(line[2])
label = line[5] label = "entailment"
except IndexError: else:
continue text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
class ColaProcessor(DataProcessor): class QqpProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the QQP data set (GLUE version)."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -555,24 +412,23 @@ class ColaProcessor(DataProcessor): ...@@ -555,24 +412,23 @@ class ColaProcessor(DataProcessor):
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "COLA" return "QQP"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
# Only the test set has a header if i == 0:
if set_type == "test" and i == 0: continue
guid = "%s-%s" % (set_type, line[0])
try:
text_a = line[3]
text_b = line[4]
label = line[5]
except IndexError:
continue continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = self.process_text_fn(line[1])
label = "0"
else:
text_a = self.process_text_fn(line[3])
label = self.process_text_fn(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -668,8 +524,14 @@ class SstProcessor(DataProcessor): ...@@ -668,8 +524,14 @@ class SstProcessor(DataProcessor):
return examples return examples
class QnliProcessor(DataProcessor): class StsBProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version).""" """Processor for the STS-B data set (GLUE version)."""
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
self.is_regression = True
self.label_type = float
self._labels = None
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -679,7 +541,7 @@ class QnliProcessor(DataProcessor): ...@@ -679,7 +541,7 @@ class QnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
...@@ -688,28 +550,26 @@ class QnliProcessor(DataProcessor): ...@@ -688,28 +550,26 @@ class QnliProcessor(DataProcessor):
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["entailment", "not_entailment"] return self._labels
@staticmethod @staticmethod
def get_processor_name(): def get_processor_name():
"""See base class.""" """See base class."""
return "QNLI" return "STS-B"
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for i, line in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, 1) guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[7])
text_b = tokenization.convert_to_unicode(line[8])
if set_type == "test": if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1]) label = 0.0
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else: else:
text_a = tokenization.convert_to_unicode(line[1]) label = self.label_type(tokenization.convert_to_unicode(line[9]))
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -888,6 +748,200 @@ class WnliProcessor(DataProcessor): ...@@ -888,6 +748,200 @@ class WnliProcessor(DataProcessor):
return examples return examples
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
super(XnliProcessor, self).__init__(process_text_fn)
if language == "all":
self.languages = XnliProcessor.supported_languages
elif language not in XnliProcessor.supported_languages:
raise ValueError("language %s is not supported for XNLI task." % language)
else:
self.languages = [language]
def get_train_examples(self, data_dir):
"""See base class."""
lines = []
for language in self.languages:
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language))[1:])
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % i
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
language = self.process_text_fn(line[0])
text_a = self.process_text_fn(line[6])
text_b = self.process_text_fn(line[7])
label = self.process_text_fn(line[1])
examples_by_lang[language].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XNLI"
class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-PAWS-X"
class XtremeXnliProcessor(DataProcessor):
"""Processor for the XTREME XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = f"test-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-XNLI"
def convert_single_example(ex_index, example, label_list, max_seq_length, def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer): tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`.""" """Converts a single `InputExample` into a single `InputFeatures`."""
......
...@@ -51,7 +51,8 @@ flags.DEFINE_string( ...@@ -51,7 +51,8 @@ flags.DEFINE_string(
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"], "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The " "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format " "difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english " "of input tsv files; 2. the dev set for XTREME is english "
...@@ -187,6 +188,8 @@ def generate_classifier_dataset(): ...@@ -187,6 +188,8 @@ def generate_classifier_dataset():
"rte": classifier_data_lib.RteProcessor, "rte": classifier_data_lib.RteProcessor,
"sst-2": "sst-2":
classifier_data_lib.SstProcessor, classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli": "xnli":
functools.partial(classifier_data_lib.XnliProcessor, functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language), language=FLAGS.xnli_language),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment