Commit 702f5898 authored by VictorSanh

fix input in run_glue for distilbert

parent 22d2fded
@@ -134,8 +134,9 @@ def train(args, train_dataset, model, tokenizer):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM, DistilBERT and RoBERTa don't use segment_ids
                       'labels': batch[3]}
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
@@ -224,8 +225,9 @@ def evaluate(args, model, tokenizer, prefix=""):
         with torch.no_grad():
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM, DistilBERT and RoBERTa don't use segment_ids
                       'labels': batch[3]}
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
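For context, here is a minimal standalone sketch of the pattern the diff introduces. The build_inputs helper and the toy batch are illustrative only, not part of the patch; the point is that DistilBERT's forward() does not accept a token_type_ids argument, so for DistilBERT the key has to be left out of the inputs dict entirely rather than passed as None.

import torch

def build_inputs(batch, model_type):
    # Assemble keyword arguments for model(**inputs); 'token_type_ids' is only
    # added for model types whose forward() accepts it (DistilBERT's does not).
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'labels': batch[3]}
    if model_type != 'distilbert':
        # BERT and XLNet use segment ids; XLM and RoBERTa accept the argument but ignore it
        inputs['token_type_ids'] = batch[2] if model_type in ['bert', 'xlnet'] else None
    return inputs

# Toy batch in the run_glue layout: (input_ids, attention_mask, token_type_ids, labels)
batch = (torch.ones(2, 8, dtype=torch.long),
         torch.ones(2, 8, dtype=torch.long),
         torch.zeros(2, 8, dtype=torch.long),
         torch.tensor([0, 1]))
print(sorted(build_inputs(batch, 'distilbert')))  # ['attention_mask', 'input_ids', 'labels']
print(sorted(build_inputs(batch, 'bert')))        # additionally contains 'token_type_ids'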