Commit 71d597da authored by thomwolf's avatar thomwolf
Browse files

fix #800

parent 4bcddf6f
...@@ -122,8 +122,8 @@ def train(args, train_dataset, model, tokenizer): ...@@ -122,8 +122,8 @@ def train(args, train_dataset, model, tokenizer):
model.train() model.train()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0], inputs = {'input_ids': batch[0],
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids 'attention_mask': batch[1],
'attention_mask': batch[2], 'token_type_ids': None if args.model_type == 'xlm' else batch[2],
'start_positions': batch[3], 'start_positions': batch[3],
'end_positions': batch[4]} 'end_positions': batch[4]}
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
...@@ -206,8 +206,9 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -206,8 +206,9 @@ def evaluate(args, model, tokenizer, prefix=""):
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad(): with torch.no_grad():
inputs = {'input_ids': batch[0], inputs = {'input_ids': batch[0],
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids 'attention_mask': batch[1],
'attention_mask': batch[2]} 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
}
example_indices = batch[3] example_indices = batch[3]
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[4], inputs.update({'cls_index': batch[4],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment