Commit 75635072 authored by LysandreJik's avatar LysandreJik
Browse files

Updated GLUE script to add DistilBERT. Cleaned up unused args in the utils file.

parent 92a9976e
...@@ -39,7 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig, ...@@ -39,7 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
XLMConfig, XLMForSequenceClassification, XLMConfig, XLMForSequenceClassification,
XLMTokenizer, XLNetConfig, XLMTokenizer, XLNetConfig,
XLNetForSequenceClassification, XLNetForSequenceClassification,
XLNetTokenizer) XLNetTokenizer,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer)
from pytorch_transformers import AdamW, WarmupLinearSchedule from pytorch_transformers import AdamW, WarmupLinearSchedule
...@@ -55,6 +58,7 @@ MODEL_CLASSES = { ...@@ -55,6 +58,7 @@ MODEL_CLASSES = {
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
} }
...@@ -128,7 +132,7 @@ def train(args, train_dataset, model, tokenizer): ...@@ -128,7 +132,7 @@ def train(args, train_dataset, model, tokenizer):
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0], inputs = {'input_ids': batch[0],
'attention_mask': batch[1], 'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
'labels': batch[3]} 'labels': batch[3]}
outputs = model(**inputs) outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
...@@ -218,7 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -218,7 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad(): with torch.no_grad():
inputs = {'input_ids': batch[0], inputs = {'input_ids': batch[0],
'attention_mask': batch[1], 'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
'labels': batch[3]} 'labels': batch[3]}
outputs = model(**inputs) outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2] tmp_eval_loss, logits = outputs[:2]
...@@ -273,11 +277,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): ...@@ -273,11 +277,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
label_list[1], label_list[2] = label_list[2], label_list[1] label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
......
...@@ -390,22 +390,12 @@ class WnliProcessor(DataProcessor): ...@@ -390,22 +390,12 @@ class WnliProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length, def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode, tokenizer, output_mode,
cls_token_at_end=False,
cls_token='[CLS]',
cls_token_segment_id=1,
sep_token='[SEP]',
sep_token_extra=False,
pad_on_left=False, pad_on_left=False,
pad_token=0, pad_token=0,
pad_token_segment_id=0, pad_token_segment_id=0,
sequence_a_segment_id=0,
sequence_b_segment_id=1,
mask_padding_with_zero=True): mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s """
`cls_token_at_end` define the location of the CLS token: Loads a data file into a list of `InputBatch`s
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
""" """
label_map = {label : i for i, label in enumerate(label_list)} label_map = {label : i for i, label in enumerate(label_list)}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment