Unverified Commit 6bd8ae26 authored by Liu Chenyang, committed by GitHub
Browse files

move preprocess_logits_for_metrics before _nested_gather in trainer.evaluation_loop (#22603)



* move preprocess_logits_for_metrics before _nested_gather in trainer.evaluation_loop

* fix

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* fix

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c582e8aa
...@@ -3182,8 +3182,6 @@ class Trainer: ...@@ -3182,8 +3182,6 @@ class Trainer:
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if labels is not None: if labels is not None:
labels = self._pad_across_processes(labels) labels = self._pad_across_processes(labels)
labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None: if inputs_decode is not None:
inputs_decode = self._pad_across_processes(inputs_decode) inputs_decode = self._pad_across_processes(inputs_decode)
inputs_decode = self._nested_gather(inputs_decode) inputs_decode = self._nested_gather(inputs_decode)
...@@ -3194,10 +3192,13 @@ class Trainer: ...@@ -3194,10 +3192,13 @@ class Trainer:
) )
if logits is not None: if logits is not None:
logits = self._pad_across_processes(logits) logits = self._pad_across_processes(logits)
logits = self._nested_gather(logits)
if self.preprocess_logits_for_metrics is not None: if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.preprocess_logits_for_metrics(logits, labels)
logits = self._nested_gather(logits)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps. # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment