"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "78284e49d7c19a6890995695bbf840f29e5c7dde"
Unverified Commit 8e0bcb56 authored by Qingqing Cao's avatar Qingqing Cao Committed by GitHub
Browse files

DataParallel fix: multi gpu evaluation (#5926)

The DataParallel training was fixed in https://github.com/huggingface/transformers/pull/5733, this commit also fixes the evaluation. It's more convenient when the user enables both `do_train` and `do_eval`.
parent a2096917
...@@ -316,7 +316,8 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -316,7 +316,8 @@ def evaluate(args, model, tokenizer, prefix=""):
inputs.update( inputs.update(
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
) )
if isinstance(model, torch.nn.DataParallel):
inputs["return_tuple"] = True
outputs = model(**inputs) outputs = model(**inputs)
for i, feature_index in enumerate(feature_indices): for i, feature_index in enumerate(feature_indices):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment