"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "54a31f50fb9403eb11c8b85057493db284f13f99"
Commit baf08ca1 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[RoBERTa] run_glue: correct pad_token + reorder labels

parent 3d87991f
...@@ -268,6 +268,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -268,6 +268,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else: else:
logger.info("Creating features from dataset file at %s", args.data_dir) logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels() label_list = processor.get_labels()
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
...@@ -276,7 +279,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -276,7 +279,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
sep_token=tokenizer.sep_token, sep_token=tokenizer.sep_token,
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0) pad_token=1 if args.model_type in ['roberta'] else 0, # TODO(Lysandre: replace with tokenizer.pad_token when implemented)
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
)
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file) logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file) torch.save(features, cached_features_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment