Unverified Commit 330247ed authored by Kian Sierra McGettigan's avatar Kian Sierra McGettigan Committed by GitHub
Browse files

Update no trainer scripts for multiple-choice (#18468)

* swag_no_trainer updated for with gather_metrics

* Removed unused variable samples_seen
parent c74befc9
...@@ -592,19 +592,11 @@ def main(): ...@@ -592,19 +592,11 @@ def main():
break break
model.eval() model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader): for step, batch in enumerate(eval_dataloader):
with torch.no_grad(): with torch.no_grad():
outputs = model(**batch) outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["labels"])) predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
samples_seen += references.shape[0]
metric.add_batch( metric.add_batch(
predictions=predictions, predictions=predictions,
references=references, references=references,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment