"vscode:/vscode.git/clone" did not exist on "088ad458885bb41694deb4e52bea914087a64dad"
Unverified Commit 534cbf8a authored by Willard Sheen's avatar Willard Sheen Committed by GitHub
Browse files

[fix bug] logits's shape different from label's shape in preprocess_logits_for_metrics (#31447)

* [fix BUG] pad labels before use it in preprocess_logits_for_metrics

* a more readable fix

labels can't use  `gather` before pass to `preprocess_logits_for_metrics`, so must split into 2 if-block

* add a comment

* oh code quality check
parent 65a02cd2
......@@ -3839,6 +3839,9 @@ class Trainer:
inputs_decode = self.gather_function((inputs_decode))
if not self.args.batch_eval_metrics or description == "Prediction":
all_inputs.add(inputs_decode)
if labels is not None:
# Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
if logits is not None:
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
......@@ -3847,7 +3850,6 @@ class Trainer:
if not self.args.batch_eval_metrics or description == "Prediction":
all_preds.add(logits)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
labels = self.gather_function((labels))
if not self.args.batch_eval_metrics or description == "Prediction":
all_labels.add(labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment