Commit 76be189b authored by Peiqin Lin's avatar Peiqin Lin
Browse files

typos

parent a6154990
......@@ -116,8 +116,8 @@ def train(args, train_dataset, model, tokenizer):
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]}
ouputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
......
......@@ -129,8 +129,8 @@ def train(args, train_dataset, model, tokenizer):
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5],
'p_mask': batch[6]})
ouputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment