Commit bd847ce7 authored by focox@qq.com

Fixed the bug raised by "eval_loss += tmp_eval_loss.item()" when evaluating in parallel on multiple GPUs.

parent ef1b8b2a
@@ -210,6 +210,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
             outputs = model(**inputs)
             tmp_eval_loss, logits = outputs[:2]
+            if args.n_gpu > 1:
+                tmp_eval_loss = tmp_eval_loss.mean()  # mean() to average on multi-gpu parallel evaluating
+
         eval_loss += tmp_eval_loss.item()
         nb_eval_steps += 1
         if preds is None:
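For context: under torch.nn.DataParallel, a model that computes its loss inside forward() returns one loss per GPU, so the evaluation loss comes back as a 1-D tensor of shape (n_gpu,) rather than a scalar, and Tensor.item() fails because it only accepts single-element tensors. Averaging with .mean() first, as the patch does, restores a scalar. Below is a minimal, self-contained sketch of that behavior; the ToyModel class and tensor shapes are illustrative assumptions, not code from the repository.

# Minimal sketch (illustrative, not from the repository) of why the patch
# averages the loss before calling .item() on multi-GPU runs.
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Toy model that computes its loss inside forward(), like the NER example."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, labels):
        logits = self.linear(x)
        loss = nn.functional.mse_loss(logits, labels)
        return loss, logits


n_gpu = torch.cuda.device_count()
model = ToyModel()
if n_gpu > 1:
    # DataParallel splits the batch across GPUs and gathers each replica's
    # scalar loss into a 1-D tensor of shape (n_gpu,).
    model = nn.DataParallel(model.cuda())

x, labels = torch.randn(8, 4), torch.randn(8, 1)
if n_gpu > 1:
    x, labels = x.cuda(), labels.cuda()

tmp_eval_loss, logits = model(x, labels)

# .item() only accepts single-element tensors, so on multi-GPU runs the
# per-replica losses must be averaged first; this mirrors the patched logic.
if n_gpu > 1:
    tmp_eval_loss = tmp_eval_loss.mean()
eval_loss = tmp_eval_loss.item()
print(eval_loss)

On a single GPU (or CPU) the loss is already a scalar tensor, so the n_gpu > 1 guard leaves single-device evaluation unchanged.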