Unverified Commit fded6f41 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Fix integration with Accelerate and failing test (#24691)

Fix integration
parent bbf30908
......@@ -3176,13 +3176,19 @@ class Trainer:
# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
all_losses = nested_numpify(losses_host)
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
all_preds = nested_numpify(preds_host)
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
all_inputs = nested_numpify(inputs_host)
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
all_labels = nested_numpify(labels_host)
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
# Number of samples
if has_length(eval_dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment