Unverified commit 4965aee0, authored by Victor SANH and committed by GitHub

[HANS] Fix label_list for RoBERTa/BART (class flipping) (#5196)

* fix label-list ordering for roberta/bart mnli-trained checkpoints

* black compliance

* isort code check
parent fc24a93e
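Why the flip is needed: MNLI-finetuned RoBERTa/BART checkpoints such as roberta-large-mnli index their classes as contradiction / neutral / entailment, while the HANS processor in this example yields contradiction / entailment / neutral, so without swapping entries 1 and 2 the gold labels written for HANS evaluation end up attached to the wrong class indices. A minimal sketch of the mismatch, assuming those two orderings (they are not stated anywhere in this diff):

# Illustration only; both orderings below are assumptions about the processor
# and the MNLI-finetuned checkpoint, not something this commit asserts.
processor_labels = ["contradiction", "entailment", "neutral"]   # order the HANS processor returns
checkpoint_labels = ["contradiction", "neutral", "entailment"]  # order e.g. roberta-large-mnli predicts in

label_list = list(processor_labels)
label_list[1], label_list[2] = label_list[2], label_list[1]     # the "class flipping" this commit adds
assert label_list == checkpoint_labels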
@@ -33,7 +33,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from utils_hans import HansDataset, InputFeatures, hans_processors
+from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
 
 logger = logging.getLogger(__name__)
@@ -130,9 +130,7 @@ def main():
     set_seed(training_args.seed)
 
     try:
-        processor = hans_processors[data_args.task_name]()
-        label_list = processor.get_labels()
-        num_labels = len(label_list)
+        num_labels = hans_tasks_num_labels[data_args.task_name]
     except KeyError:
         raise ValueError("Task not found: %s" % (data_args.task_name))
@@ -214,6 +212,7 @@ def main():
         pair_ids = [ex.pairID for ex in eval_dataset]
 
         output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt")
+        label_list = eval_dataset.get_labels()
         if trainer.is_world_master():
             with open(output_eval_file, "w") as writer:
                 writer.write("pairID,gold_label\n")
...
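The label_list fetched here feeds the prediction-writing loop that follows in the evaluation script (presumably test_hans.py, since it imports utils_hans). A hedged sketch of that surrounding logic, written as a hypothetical helper; the argmax/zip details and the helper name write_hans_predictions are assumptions, not shown in this hunk:

import numpy as np

def write_hans_predictions(trainer, eval_dataset, pair_ids, output_eval_file):
    # Hypothetical helper mirroring what the script presumably does around this hunk.
    predictions = trainer.predict(test_dataset=eval_dataset)
    preds = np.argmax(predictions.predictions, axis=1)
    # get_labels() returns the (possibly flipped) ordering stored by HansDataset.
    label_list = eval_dataset.get_labels()
    with open(output_eval_file, "w") as writer:
        writer.write("pairID,gold_label\n")
        for pair_id, pred in zip(pair_ids, preds):
            writer.write("%s,%s\n" % (pair_id, label_list[int(pred)]))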
@@ -22,7 +22,17 @@ from typing import List, Optional, Union
 import tqdm
 from filelock import FileLock
 
-from transformers import DataProcessor, PreTrainedTokenizer, is_tf_available, is_torch_available
+from transformers import (
+    BartTokenizer,
+    BartTokenizerFast,
+    DataProcessor,
+    PreTrainedTokenizer,
+    RobertaTokenizer,
+    RobertaTokenizerFast,
+    XLMRobertaTokenizer,
+    is_tf_available,
+    is_torch_available,
+)
 
 logger = logging.getLogger(__name__)
@@ -105,6 +115,17 @@ if is_torch_available():
                     "dev" if evaluate else "train", tokenizer.__class__.__name__, str(max_seq_length), task,
                 ),
             )
+            label_list = processor.get_labels()
+            if tokenizer.__class__ in (
+                RobertaTokenizer,
+                RobertaTokenizerFast,
+                XLMRobertaTokenizer,
+                BartTokenizer,
+                BartTokenizerFast,
+            ):
+                # HACK(label indices are swapped in RoBERTa pretrained model)
+                label_list[1], label_list[2] = label_list[2], label_list[1]
+            self.label_list = label_list
 
             # Make sure only the first process in distributed training processes the dataset,
             # and the others will use the cache.
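A quick way to sanity-check the hard-coded swap above is to compare the dataset's label order with the id2label map of an MNLI-finetuned checkpoint; a small sketch, with the caveat that the printed mapping depends on the checkpoint and is an assumption here rather than something the diff guarantees:

from transformers import AutoConfig

# For roberta-large-mnli this has historically printed something like
# {0: "CONTRADICTION", 1: "NEUTRAL", 2: "ENTAILMENT"}, which is why entries
# 1 and 2 of ["contradiction", "entailment", "neutral"] get swapped above.
config = AutoConfig.from_pretrained("roberta-large-mnli")
print(config.id2label)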
@@ -116,7 +137,6 @@ if is_torch_available():
                     self.features = torch.load(cached_features_file)
                 else:
                     logger.info(f"Creating features from dataset file at {data_dir}")
-                    label_list = processor.get_labels()
                     examples = (
                         processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
@@ -133,6 +153,9 @@ if is_torch_available():
         def __getitem__(self, i) -> InputFeatures:
             return self.features[i]
 
+        def get_labels(self):
+            return self.label_list
 
 
 if is_tf_available():
     import tensorflow as tf
@@ -156,6 +179,16 @@ if is_tf_available():
         ):
             processor = hans_processors[task]()
             label_list = processor.get_labels()
+            if tokenizer.__class__ in (
+                RobertaTokenizer,
+                RobertaTokenizerFast,
+                XLMRobertaTokenizer,
+                BartTokenizer,
+                BartTokenizerFast,
+            ):
+                # HACK(label indices are swapped in RoBERTa pretrained model)
+                label_list[1], label_list[2] = label_list[2], label_list[1]
+            self.label_list = label_list
 
             examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
 
             self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
@@ -206,6 +239,9 @@ if is_tf_available():
         def __getitem__(self, i) -> InputFeatures:
             return self.features[i]
 
+        def get_labels(self):
+            return self.label_list
 
 
 class HansProcessor(DataProcessor):
     """Processor for the HANS data set."""
...
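With the new get_labels() accessor, callers read the (possibly flipped) label order straight from the dataset instead of the processor. A minimal usage sketch; the constructor arguments and the local ./hans_data path are assumptions about utils_hans, so check the actual HansDataset signature before relying on them:

from transformers import RobertaTokenizer
from utils_hans import HansDataset

tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
# Constructor arguments below are assumed, not taken from the hunks above.
eval_dataset = HansDataset(
    data_dir="./hans_data",     # hypothetical local copy of the HANS dataset
    tokenizer=tokenizer,
    task="hans",
    max_seq_length=128,
    evaluate=True,
)
# Entries 1 and 2 are swapped because a RoBERTa tokenizer was passed in.
print(eval_dataset.get_labels())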