Commit 431a9ca3 authored by stephenwu

added AX-g preprocessor

parent 80993c41
@@ -18,6 +18,7 @@ import collections
 import csv
 import importlib
 import os
+import json
 from absl import logging
 import tensorflow as tf
@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
   return feature


+class AXgProcessor(DataProcessor):
+  """Processor for the AX dataset (GLUE diagnostics dataset)."""
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_jsonl(os.path.join(data_dir, "dev.jsonl")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["entailment", "not_entailment"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "AXg"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for line in lines:
+      guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
+      text_a = self.process_text_fn(line["hypothesis"])
+      text_b = self.process_text_fn(line["premise"])
+      label = self.process_text_fn(line["label"])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def _read_jsonl(self, input_path):
+    with tf.io.gfile.GFile(input_path, "r") as f:
+      lines = []
+      for json_str in f:
+        lines.append(json.loads(json_str))
+    return lines


 def file_based_convert_examples_to_features(examples,
                                             label_list,
@@ -1374,13 +1415,14 @@ def generate_tf_record_from_data_file(processor,
   label_type = getattr(processor, "label_type", None)
   is_regression = getattr(processor, "is_regression", False)
   has_sample_weights = getattr(processor, "weight_key", False)

-  assert train_data_output_path
-  train_input_data_examples = processor.get_train_examples(data_dir)
-  file_based_convert_examples_to_features(train_input_data_examples, label_list,
-                                          max_seq_length, tokenizer,
-                                          train_data_output_path, label_type)
-  num_training_data = len(train_input_data_examples)
+  num_training_data = 0
+  if train_data_output_path:
+    train_input_data_examples = processor.get_train_examples(data_dir)
+    file_based_convert_examples_to_features(train_input_data_examples, label_list,
+                                            max_seq_length, tokenizer,
+                                            train_data_output_path, label_type)
+    num_training_data = len(train_input_data_examples)

   if eval_data_output_path:
     eval_input_data_examples = processor.get_dev_examples(data_dir)
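
The new AXgProcessor follows the same DataProcessor pattern as the other GLUE processors in classifier_data_lib: the dev/test splits are read from JSONL files (one JSON object per line with "idx", "hypothesis", "premise", and "label" fields) and turned into InputExample pairs. Below is a minimal sketch of exercising it directly, assuming the SuperGLUE AX-g files sit under a local directory; the import path and the use of DataProcessor's default process_text_fn are assumptions, not part of the commit.

```python
# Illustrative sketch only (not part of the commit). The data path and module
# import location are assumptions; the processor itself comes from the diff above.
from official.nlp.bert import classifier_data_lib

data_dir = "/tmp/AX-g"  # hypothetical directory holding dev.jsonl / test.jsonl

# Relies on DataProcessor's default process_text_fn.
processor = classifier_data_lib.AXgProcessor()
print(processor.get_processor_name(), processor.get_labels())

# Each InputExample pairs a hypothesis (text_a) with a premise (text_b).
for example in processor.get_dev_examples(data_dir)[:3]:
  print(example.guid, example.label, example.text_a, example.text_b)
```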
@@ -49,7 +49,7 @@ flags.DEFINE_string(
 flags.DEFINE_enum(
     "classification_task_name", "MNLI", [
         "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
-        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"
+        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g"
     ], "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "
@@ -238,7 +238,10 @@ def generate_classifier_dataset():
           functools.partial(
               classifier_data_lib.XtremePawsxProcessor,
               translated_data_dir=FLAGS.translated_input_data_dir,
-              only_use_en_dev=FLAGS.only_use_en_dev)
+              only_use_en_dev=FLAGS.only_use_en_dev),
+      "ax-g":
+          classifier_data_lib.AXgProcessor
   }
   task_name = FLAGS.classification_task_name.lower()
   if task_name not in processors:
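
Note that the flag accepts the mixed-case value "AX-g" while the processor map is keyed by the lowercased "ax-g"; generate_classifier_dataset() lowercases the flag value before the lookup, so the two stay consistent. A minimal sketch of that dispatch follows, with the dict abridged to the entry added here.

```python
# Abridged illustration of the task-name lookup in generate_classifier_dataset();
# only the entry added by this commit is shown.
processors = {
    "ax-g": classifier_data_lib.AXgProcessor,
}

task_name = "AX-g".lower()  # value passed via --classification_task_name
if task_name not in processors:
  raise ValueError("Task not found: %s" % task_name)
# The real script may also pass a process_text_fn to the constructor; omitted here.
processor = processors[task_name]()
```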