Unverified Commit 09a272b0 authored by Laura Vasquez-Rodriguez's avatar Laura Vasquez-Rodriguez Committed by GitHub
Browse files

Add inputs vector to calculate metric method (#16461)

* Add inputs vector to calculate metric method

* Include inputs for evaluation metrics with backwards compatibility

* Prevent inputs create OOM issue and documentation details

* Update style and code documentation

* Fix style formatting issues

* Update files format with make style
parent dc991805
...@@ -2437,10 +2437,13 @@ class Trainer: ...@@ -2437,10 +2437,13 @@ class Trainer:
losses_host = None losses_host = None
preds_host = None preds_host = None
labels_host = None labels_host = None
inputs_host = None
# losses/preds/labels on CPU (final containers) # losses/preds/labels on CPU (final containers)
all_losses = None all_losses = None
all_preds = None all_preds = None
all_labels = None all_labels = None
all_inputs = None
# Will be useful when we have an iterable dataset so don't know its length. # Will be useful when we have an iterable dataset so don't know its length.
observed_num_examples = 0 observed_num_examples = 0
...@@ -2456,6 +2459,7 @@ class Trainer: ...@@ -2456,6 +2459,7 @@ class Trainer:
# Prediction step # Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
if is_torch_tpu_available(): if is_torch_tpu_available():
xm.mark_step() xm.mark_step()
...@@ -2468,6 +2472,14 @@ class Trainer: ...@@ -2468,6 +2472,14 @@ class Trainer:
labels = self._pad_across_processes(labels) labels = self._pad_across_processes(labels)
labels = self._nested_gather(labels) labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None:
inputs_decode = self._pad_across_processes(inputs_decode)
inputs_decode = self._nested_gather(inputs_decode)
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
if logits is not None: if logits is not None:
logits = self._pad_across_processes(logits) logits = self._pad_across_processes(logits)
logits = self._nested_gather(logits) logits = self._nested_gather(logits)
...@@ -2484,6 +2496,13 @@ class Trainer: ...@@ -2484,6 +2496,13 @@ class Trainer:
if preds_host is not None: if preds_host is not None:
logits = nested_numpify(preds_host) logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None: if labels_host is not None:
labels = nested_numpify(labels_host) labels = nested_numpify(labels_host)
all_labels = ( all_labels = (
...@@ -2491,7 +2510,7 @@ class Trainer: ...@@ -2491,7 +2510,7 @@ class Trainer:
) )
# Set back to None to begin a new accumulation # Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None losses_host, preds_host, inputs_host, labels_host = None, None, None, None
if args.past_index and hasattr(self, "_past"): if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop # Clean the state at the end of the evaluation loop
...@@ -2504,6 +2523,11 @@ class Trainer: ...@@ -2504,6 +2523,11 @@ class Trainer:
if preds_host is not None: if preds_host is not None:
logits = nested_numpify(preds_host) logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None: if labels_host is not None:
labels = nested_numpify(labels_host) labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
...@@ -2529,9 +2553,16 @@ class Trainer: ...@@ -2529,9 +2553,16 @@ class Trainer:
all_preds = nested_truncate(all_preds, num_samples) all_preds = nested_truncate(all_preds, num_samples)
if all_labels is not None: if all_labels is not None:
all_labels = nested_truncate(all_labels, num_samples) all_labels = nested_truncate(all_labels, num_samples)
if all_inputs is not None:
all_inputs = nested_truncate(all_inputs, num_samples)
# Metrics! # Metrics!
if self.compute_metrics is not None and all_preds is not None and all_labels is not None: if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
)
else:
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
else: else:
metrics = {} metrics = {}
...@@ -2913,7 +2944,6 @@ class Trainer: ...@@ -2913,7 +2944,6 @@ class Trainer:
# if eval is called w/o train init deepspeed here # if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed: if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually # from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
...@@ -2944,6 +2974,7 @@ class Trainer: ...@@ -2944,6 +2974,7 @@ class Trainer:
losses_host: torch.Tensor = None losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
world_size = max(1, args.world_size) world_size = max(1, args.world_size)
...@@ -2956,6 +2987,7 @@ class Trainer: ...@@ -2956,6 +2987,7 @@ class Trainer:
make_multiple_of = dataloader.sampler.batch_size make_multiple_of = dataloader.sampler.batch_size
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
model.eval() model.eval()
...@@ -2969,6 +3001,8 @@ class Trainer: ...@@ -2969,6 +3001,8 @@ class Trainer:
for step, inputs in enumerate(dataloader): for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
if loss is not None: if loss is not None:
losses = loss.repeat(batch_size) losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
...@@ -2976,6 +3010,12 @@ class Trainer: ...@@ -2976,6 +3010,12 @@ class Trainer:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None: if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None:
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps. # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
...@@ -2984,9 +3024,10 @@ class Trainer: ...@@ -2984,9 +3024,10 @@ class Trainer:
if not prediction_loss_only: if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
# Set back to None to begin a new accumulation # Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None losses_host, preds_host, labels_host, inputs_host = None, None, None, None
if args.past_index and hasattr(self, "_past"): if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop # Clean the state at the end of the evaluation loop
...@@ -2997,12 +3038,19 @@ class Trainer: ...@@ -2997,12 +3038,19 @@ class Trainer:
if not prediction_loss_only: if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
eval_loss = eval_losses_gatherer.finalize() eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize() if not prediction_loss_only else None preds = preds_gatherer.finalize() if not prediction_loss_only else None
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
if self.compute_metrics is not None and preds is not None and label_ids is not None: if self.compute_metrics is not None and preds is not None and label_ids is not None:
if args.include_inputs_for_metrics:
metrics = self.compute_metrics(
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
)
else:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else: else:
metrics = {} metrics = {}
......
...@@ -63,17 +63,43 @@ def set_seed(seed: int): ...@@ -63,17 +63,43 @@ def set_seed(seed: int):
tf.random.set_seed(seed) tf.random.set_seed(seed)
class EvalPrediction(NamedTuple): class EvalPrediction:
""" """
Evaluation output (always contains labels), to be used to compute metrics. Evaluation output (always contains labels), to be used to compute metrics.
Parameters: Parameters:
predictions (`np.ndarray`): Predictions of the model. predictions (`np.ndarray`): Predictions of the model.
label_ids (`np.ndarray`): Targets to be matched. label_ids (`np.ndarray`): Targets to be matched.
inputs (`np.ndarray`, *optional*): Input data passed to the model.
""" """
predictions: Union[np.ndarray, Tuple[np.ndarray]] def __init__(
label_ids: Union[np.ndarray, Tuple[np.ndarray]] self,
predictions: Union[np.ndarray, Tuple[np.ndarray]],
label_ids: Union[np.ndarray, Tuple[np.ndarray]],
inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
):
self.predictions = predictions
self.label_ids = label_ids
self.inputs = inputs
def __iter__(self):
if self.inputs is not None:
return iter((self.predictions, self.label_ids, self.inputs))
else:
return iter((self.predictions, self.label_ids))
def __getitem__(self, idx):
if idx < 0 or idx > 2:
raise IndexError("tuple index out of range")
if idx == 2 and self.inputs is None:
raise IndexError("tuple index out of range")
if idx == 0:
return self.predictions
elif idx == 1:
return self.label_ids
elif idx == 2:
return self.inputs
class EvalLoopOutput(NamedTuple): class EvalLoopOutput(NamedTuple):
......
...@@ -416,6 +416,9 @@ class TrainingArguments: ...@@ -416,6 +416,9 @@ class TrainingArguments:
`huggingface-cli login`. `huggingface-cli login`.
gradient_checkpointing (`bool`, *optional*, defaults to `False`): gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass. If True, use gradient checkpointing to save memory at the expense of slower backward pass.
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
""" """
output_dir: str = field( output_dir: str = field(
...@@ -739,6 +742,9 @@ class TrainingArguments: ...@@ -739,6 +742,9 @@ class TrainingArguments:
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
}, },
) )
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
# Deprecated arguments # Deprecated arguments
fp16_backend: str = field( fp16_backend: str = field(
default="auto", default="auto",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment