"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e08c01aa1ad63efff83548ea69d5ba3ce4a75acc"
Commit a26ce4de authored by Stefan Schweter's avatar Stefan Schweter
Browse files

examples: add XLM-RoBERTa to glue script

parent fe9aab10
...@@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig, ...@@ -52,6 +52,9 @@ from transformers import (WEIGHTS_NAME, BertConfig,
AlbertConfig, AlbertConfig,
AlbertForSequenceClassification, AlbertForSequenceClassification,
AlbertTokenizer, AlbertTokenizer,
XLMRobertaConfig,
XLMRobertaForSequenceClassification,
XLMRobertaTokenizer,
) )
from transformers import AdamW, get_linear_schedule_with_warmup from transformers import AdamW, get_linear_schedule_with_warmup
...@@ -72,7 +75,8 @@ MODEL_CLASSES = { ...@@ -72,7 +75,8 @@ MODEL_CLASSES = {
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer) 'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
} }
...@@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -304,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else: else:
logger.info("Creating features from dataset file at %s", args.data_dir) logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels() label_list = processor.get_labels()
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']: if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
# HACK(label indices are swapped in RoBERTa pretrained model) # HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1] label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, features = convert_examples_to_features(examples,
tokenizer, tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment