"git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "377cd9e468499d74abbe46a08886d63d3389235b"
Commit 75635072 authored by LysandreJik

Updated GLUE script to add DistilBERT. Cleaned up unused args in the utils file.

parent 92a9976e
GLUE script:

@@ -39,7 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
                                   XLMConfig, XLMForSequenceClassification,
                                   XLMTokenizer, XLNetConfig,
                                   XLNetForSequenceClassification,
-                                  XLNetTokenizer)
+                                  XLNetTokenizer,
+                                  DistilBertConfig,
+                                  DistilBertForSequenceClassification,
+                                  DistilBertTokenizer)
 from pytorch_transformers import AdamW, WarmupLinearSchedule
@@ -55,6 +58,7 @@ MODEL_CLASSES = {
     'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
     'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
     'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
+    'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
 }
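The new entry lets a single model_type value select DistilBERT's config, model, and tokenizer classes at once. A minimal sketch of how such a lookup is typically resolved into pretrained objects (the 'distilbert-base-uncased' shortcut name and num_labels=2 are illustrative assumptions, not part of this diff):

from pytorch_transformers import (DistilBertConfig,
                                  DistilBertForSequenceClassification,
                                  DistilBertTokenizer)

MODEL_CLASSES = {
    'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
}

# Resolve the three classes from the chosen model type, then load pretrained weights.
config_class, model_class, tokenizer_class = MODEL_CLASSES['distilbert']
config = config_class.from_pretrained('distilbert-base-uncased', num_labels=2)  # assumed shortcut name
tokenizer = tokenizer_class.from_pretrained('distilbert-base-uncased')
model = model_class.from_pretrained('distilbert-base-uncased', config=config)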
@@ -128,7 +132,7 @@ def train(args, train_dataset, model, tokenizer):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM and RoBERTa don't use segment_ids
+                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM, DistilBERT and RoBERTa don't use segment_ids
                       'labels': batch[3]}
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
@@ -218,7 +222,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         with torch.no_grad():
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM and RoBERTa don't use segment_ids
+                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM, DistilBERT and RoBERTa don't use segment_ids
                       'labels': batch[3]}
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
@@ -273,11 +277,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
             label_list[1], label_list[2] = label_list[2], label_list[1]
         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
         features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
-            cls_token_at_end=bool(args.model_type in ['xlnet']),  # xlnet has a cls token at the end
-            cls_token=tokenizer.cls_token,
-            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-            sep_token=tokenizer.sep_token,
-            sep_token_extra=bool(args.model_type in ['roberta']),  # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
             pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
             pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
             pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
...
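For context, the features returned here are what the train and evaluate loops above unpack as batch[0] through batch[3]. A hedged sketch of how such features are typically turned into a TensorDataset in this script (field names assumed from the utils file's InputFeatures; not part of this diff):

import torch
from torch.utils.data import TensorDataset

# Assumed InputFeatures fields: input_ids, input_mask, segment_ids, label_id.
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

# Order matters: batch[0] is input_ids, batch[1] attention_mask,
# batch[2] token_type_ids, batch[3] labels, matching the inputs dicts above.
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)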
utils file:

@@ -390,22 +390,12 @@ class WnliProcessor(DataProcessor):
 def convert_examples_to_features(examples, label_list, max_seq_length,
                                  tokenizer, output_mode,
-                                 cls_token_at_end=False,
-                                 cls_token='[CLS]',
-                                 cls_token_segment_id=1,
-                                 sep_token='[SEP]',
-                                 sep_token_extra=False,
                                  pad_on_left=False,
                                  pad_token=0,
                                  pad_token_segment_id=0,
-                                 sequence_a_segment_id=0,
-                                 sequence_b_segment_id=1,
                                  mask_padding_with_zero=True):
-    """ Loads a data file into a list of `InputBatch`s
-        `cls_token_at_end` define the location of the CLS token:
-            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
-            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
-        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
+    """
+        Loads a data file into a list of `InputBatch`s
     """
     label_map = {label : i for i, label in enumerate(label_list)}
...
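The keyword arguments that survive the cleanup all feed the padding step at the end of feature conversion. A standalone sketch of that logic under the simplified signature (pad_features is a hypothetical helper, not code from this commit):

def pad_features(input_ids, input_mask, segment_ids, max_seq_length,
                 pad_on_left=False, pad_token=0, pad_token_segment_id=0,
                 mask_padding_with_zero=True):
    # Pad token ids, attention mask, and segment ids out to max_seq_length,
    # on the left for XLNet-style models, on the right otherwise.
    padding_length = max_seq_length - len(input_ids)
    pad_ids = [pad_token] * padding_length
    pad_mask = [0 if mask_padding_with_zero else 1] * padding_length
    pad_segs = [pad_token_segment_id] * padding_length
    if pad_on_left:
        return pad_ids + input_ids, pad_mask + input_mask, pad_segs + segment_ids
    return input_ids + pad_ids, input_mask + pad_mask, segment_ids + pad_segs

# Toy example: a 3-token sequence padded to length 6.
ids, mask, segs = pad_features([101, 2023, 102], [1, 1, 1], [0, 0, 0], 6)
assert ids == [101, 2023, 102, 0, 0, 0] and mask == [1, 1, 1, 0, 0, 0]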