Commit 431a9ca3 authored by stephenwu
Browse files

Added AX-g processor

parent 80993c41
...@@ -18,6 +18,7 @@ import collections ...@@ -18,6 +18,7 @@ import collections
import csv import csv
import importlib import importlib
import os import os
import json
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
return feature return feature
class AXgProcessor(DataProcessor):
  """Processor for the AX-g diagnostic dataset stored as JSON Lines.

  Examples are read from `dev.jsonl` / `test.jsonl` inside the data
  directory. Every record must provide the keys `idx`, `premise`,
  `hypothesis` and `label`. This class defines only dev and test readers.
  """

  def get_dev_examples(self, data_dir):
    """See base class. Reads `dev.jsonl` from `data_dir`."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "dev.jsonl")), "dev")

  def get_test_examples(self, data_dir):
    """See base class. Reads `test.jsonl` from `data_dir`."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")

  def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AXg"

  def _create_examples(self, lines, set_type):
    """Creates examples for the dev/test sets.

    Args:
      lines: Iterable of decoded JSON records (dicts) with the keys `idx`,
        `premise`, `hypothesis` and `label`.
      set_type: Split name used as the guid prefix, e.g. "dev" or "test".

    Returns:
      A list of `InputExample`, one per record, in input order.

    Raises:
      KeyError: If a record is missing one of the required keys.
    """
    examples = []
    for line in lines:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
      # NOTE(review): the hypothesis is used as the first segment and the
      # premise as the second. Confirm this matches the segment order of the
      # processor the evaluated classifier was trained with.
      text_a = self.process_text_fn(line["hypothesis"])
      text_b = self.process_text_fn(line["premise"])
      label = self.process_text_fn(line["label"])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def _read_jsonl(self, input_path):
    """Reads a JSON Lines file into a list of decoded records.

    Whitespace-only lines (for example a trailing empty line at the end of
    the file) are skipped instead of being passed to the JSON decoder.

    Args:
      input_path: Path of the `.jsonl` file to read.

    Returns:
      A list with one decoded JSON value per non-blank line, in file order.
    """
    with tf.io.gfile.GFile(input_path, "r") as f:
      return [json.loads(json_str) for json_str in f if json_str.strip()]
def file_based_convert_examples_to_features(examples, def file_based_convert_examples_to_features(examples,
label_list, label_list,
...@@ -1374,8 +1415,9 @@ def generate_tf_record_from_data_file(processor, ...@@ -1374,8 +1415,9 @@ def generate_tf_record_from_data_file(processor,
label_type = getattr(processor, "label_type", None) label_type = getattr(processor, "label_type", None)
is_regression = getattr(processor, "is_regression", False) is_regression = getattr(processor, "is_regression", False)
has_sample_weights = getattr(processor, "weight_key", False) has_sample_weights = getattr(processor, "weight_key", False)
assert train_data_output_path
num_training_data = 0
if train_data_output_path:
train_input_data_examples = processor.get_train_examples(data_dir) train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list, file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer, max_seq_length, tokenizer,
......
...@@ -49,7 +49,7 @@ flags.DEFINE_string( ...@@ -49,7 +49,7 @@ flags.DEFINE_string(
flags.DEFINE_enum( flags.DEFINE_enum(
"classification_task_name", "MNLI", [ "classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X" "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X", "AX-g"
], "The name of the task to train BERT classifier. The " ], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format " "difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english " "of input tsv files; 2. the dev set for XTREME is english "
...@@ -238,7 +238,10 @@ def generate_classifier_dataset(): ...@@ -238,7 +238,10 @@ def generate_classifier_dataset():
functools.partial( functools.partial(
classifier_data_lib.XtremePawsxProcessor, classifier_data_lib.XtremePawsxProcessor,
translated_data_dir=FLAGS.translated_input_data_dir, translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev) only_use_en_dev=FLAGS.only_use_en_dev),
"ax-g":
classifier_data_lib.AXgProcessor
} }
task_name = FLAGS.classification_task_name.lower() task_name = FLAGS.classification_task_name.lower()
if task_name not in processors: if task_name not in processors:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment