Unverified Commit ad98642a authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Fix gather for metrics (#19360)

parent d9101b71
......@@ -685,7 +685,7 @@ def main():
# If we did not pad to max length, we need to pad the labels too
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels)
generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
generated_tokens = generated_tokens.cpu().numpy()
labels = labels.cpu().numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment