"test/ut/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b7f374cebc8dbbd0ab7420e9d6fe53bbabe7f713"
Unverified Commit 87dd1a00 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix metric computation in `run_glue_no_trainer` (#11569)

parent a721a5ee
...@@ -404,7 +404,7 @@ def main(): ...@@ -404,7 +404,7 @@ def main():
model.eval() model.eval()
for step, batch in enumerate(eval_dataloader): for step, batch in enumerate(eval_dataloader):
outputs = model(**batch) outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch( metric.add_batch(
predictions=accelerator.gather(predictions), predictions=accelerator.gather(predictions),
references=accelerator.gather(batch["labels"]), references=accelerator.gather(batch["labels"]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment