Unverified commit df475bf8, authored by Nate Cibik, committed by GitHub

Trainer - add cache clearing and the option for batched eval metrics computation (#28769)

* Added cache clearing for GPU efficiency.

* Added cache clearing for GPU efficiency.

* Added batch_eval_metrics capability

* Ran make fixup

* Fixed bug

* Fixed whitespace issue

* Fixed outdated condition

* Updated docstrings with instructions for batch_eval_metrics. Updated end of dataloader logic

* Added first version of batch_eval_metrics Trainer test

* Fixed batch_eval_metrics Trainer tests for both eval and predict

* Fixed batch_eval_metrics behavior for new Trainer variables

* Fixed batch_eval_metrics Trainer tests

* Ran fixup
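
For context, here is a rough sketch of how the new option fits together, based on the docstrings and tests added in this commit: the metric callable accumulates a per-batch statistic and only reduces it to the final metrics when `compute_result=True` (on the last eval batch). The `model`, `eval_dataset`, and `BatchedAccuracy` names below are illustrative placeholders, not part of the change:

```python
import numpy as np
import torch
from transformers import Trainer, TrainingArguments


class BatchedAccuracy:
    """Accumulates a per-batch accuracy; returns the summary only when `compute_result=True`."""

    def __init__(self):
        self.batch_acc = []

    def __call__(self, eval_pred, compute_result):
        predictions, labels = eval_pred  # tensors for the current eval batch
        # Per-batch statistic: fraction of correct predictions in this batch (assumes classification logits).
        acc = (torch.argmax(predictions, dim=-1) == labels).float().mean().item()
        self.batch_acc.append(acc)
        if compute_result:
            # Triggered on the last eval batch: reduce the accumulated batch-level statistics.
            result = {"accuracy": float(np.mean(self.batch_acc))}
            self.batch_acc = []  # reset for the next evaluation run
            return result


args = TrainingArguments(output_dir="out", batch_eval_metrics=True)  # flag added by this commit
trainer = Trainer(
    model=model,  # placeholder model
    args=args,
    eval_dataset=eval_dataset,  # placeholder dataset
    compute_metrics=BatchedAccuracy(),  # must accept `compute_result`; otherwise Trainer raises a ValueError at init
)
metrics = trainer.evaluate()
```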
parent e0769530
@@ -327,7 +327,10 @@ class Trainer:
             inner layers, dropout probabilities etc).
         compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
             The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
-            a dictionary string to metric values.
+            a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
+            `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
+            after the last eval batch to signal that the function needs to calculate and return the global summary
+            statistics rather than accumulating the batch-level statistics.
         callbacks (List of [`TrainerCallback`], *optional*):
             A list of callbacks to customize the training loop. Will add those to the list of default callbacks
             detailed in [here](callback).
@@ -382,6 +385,13 @@ class Trainer:
             output_dir = "tmp_trainer"
             logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
             args = TrainingArguments(output_dir=output_dir)
+        if args.batch_eval_metrics and compute_metrics is not None:
+            if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
+                raise ValueError(
+                    "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
+                    " boolean argument which will be triggered after the last batch of the eval set to signal that the"
+                    " summary statistics should be returned by the function."
+                )
         self.args = args
         # Seed must be set before instantiating the model when using model
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
@@ -3205,6 +3215,9 @@ class Trainer:
         with self.compute_loss_context_manager():
             loss = self.compute_loss(model, inputs)
+        del inputs
+        torch.cuda.empty_cache()
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
@@ -3703,6 +3716,8 @@ class Trainer:
         all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
         all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        metrics = None
         # Will be useful when we have an iterable dataset so don't know its length.
         observed_num_examples = 0
@@ -3731,27 +3746,50 @@ class Trainer:
             if inputs_decode is not None:
                 inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.gather_function((inputs_decode))
-                all_inputs.add(inputs_decode)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_inputs.add(inputs_decode)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
-                all_preds.add(logits)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_preds.add(logits)
             if labels is not None:
                 labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
-                all_labels.add(labels)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_labels.add(labels)
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and logits is not None and labels is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    if args.include_inputs_for_metrics:
+                        metrics = self.compute_metrics(
+                            EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
+                            compute_result=is_last_step,
+                        )
+                    else:
+                        metrics = self.compute_metrics(
+                            EvalPrediction(predictions=logits, label_ids=labels),
+                            compute_result=is_last_step,
+                        )
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+            elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                 all_losses.to_cpu_and_numpy()
                 all_preds.to_cpu_and_numpy()
                 all_labels.to_cpu_and_numpy()
                 all_inputs.to_cpu_and_numpy()
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
         # After all calls to `.gather_function`, reset to `gather_for_metrics`:
         self.gather_function = self.accelerator.gather_for_metrics
         if args.past_index and hasattr(self, "_past"):
@@ -3780,14 +3818,19 @@ class Trainer:
         num_samples = observed_num_examples
         # Metrics!
-        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+        if (
+            self.compute_metrics is not None
+            and all_preds is not None
+            and all_labels is not None
+            and not self.args.batch_eval_metrics
+        ):
             if args.include_inputs_for_metrics:
                 metrics = self.compute_metrics(
                     EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                 )
             else:
                 metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
-        else:
+        elif metrics is None:
             metrics = {}
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
@@ -4243,6 +4286,7 @@ class Trainer:
         preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
         inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
+        metrics: Optional[dict] = None
         world_size = max(1, args.world_size)
@@ -4284,8 +4328,24 @@ class Trainer:
                 )
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
-            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    if args.include_inputs_for_metrics:
+                        metrics = self.compute_metrics(
+                            EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
+                            compute_result=is_last_step,
+                        )
+                    else:
+                        metrics = self.compute_metrics(
+                            EvalPrediction(predictions=preds_host, label_ids=labels_host),
+                            compute_result=is_last_step,
+                        )
+            if self.args.batch_eval_metrics or (
+                args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
+            ):
+                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
                 eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                 if not prediction_loss_only:
                     preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
@@ -4293,6 +4353,8 @@ class Trainer:
                     inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
                 # Set back to None to begin a new accumulation
+                del losses_host, preds_host, labels_host, inputs_host
+                torch.cuda.empty_cache()
                 losses_host, preds_host, labels_host, inputs_host = None, None, None, None
             if args.past_index and hasattr(self, "_past"):
@@ -4311,14 +4373,19 @@ class Trainer:
         label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
         inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
-        if self.compute_metrics is not None and preds is not None and label_ids is not None:
+        if (
+            self.compute_metrics is not None
+            and preds is not None
+            and label_ids is not None
+            and not self.args.batch_eval_metrics
+        ):
             if args.include_inputs_for_metrics:
                 metrics = self.compute_metrics(
                     EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                 )
             else:
                 metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
-        else:
+        elif metrics is None:
             metrics = {}
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
...
@@ -756,6 +756,12 @@ class TrainingArguments:
             See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
             optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
             only.
+        batch_eval_metrics (`Optional[bool]`, defaults to `False`):
+            If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
+            rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
+            that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
+            summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
     """
     framework = "pt"
@@ -1434,6 +1440,11 @@ class TrainingArguments:
         },
     )
+    batch_eval_metrics: bool = field(
+        default=False,
+        metadata={"help": "Break eval metrics calculation into batches to save memory."},
+    )
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
...
@@ -230,6 +230,27 @@ class AlmostAccuracy:
         return {"accuracy": true.astype(np.float32).mean().item()}
+class AlmostAccuracyBatched:
+    def __init__(self, thresh=0.25):
+        self.thresh = thresh
+        self.batch_acc = []
+
+    def __call__(self, eval_pred, compute_result):
+        predictions, labels = eval_pred
+        if isinstance(predictions, tuple):
+            predictions = predictions[0]
+        if isinstance(labels, tuple):
+            labels = labels[0]
+        batch_size = len(predictions)
+        true = torch.abs(predictions - labels) <= self.thresh
+        acc = true.type(torch.FloatTensor).mean().item()
+        self.batch_acc.extend([acc] * batch_size)
+        if compute_result:
+            result = {"accuracy": np.mean(self.batch_acc).item()}
+            self.batch_acc = []
+            return result
 class RegressionModelConfig(PretrainedConfig):
     def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
         super().__init__(**kwargs)
@@ -1524,6 +1545,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
         self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+    def test_evaluate_with_batch_eval_metrics(self):
+        trainer = get_regression_trainer(
+            a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
+        )
+        results = trainer.evaluate()
+
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+        # With a number of elements not a round multiple of the batch size
+        trainer = get_regression_trainer(
+            a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
+        )
+        results = trainer.evaluate()
+
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+        # With logits preprocess
+        trainer = get_regression_trainer(
+            a=1.5,
+            b=2.5,
+            compute_metrics=AlmostAccuracyBatched(),
+            batch_eval_metrics=True,
+            preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
+        )
+        results = trainer.evaluate()
+
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
     def test_evaluate_with_jit(self):
         trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
         results = trainer.evaluate()
@@ -1651,6 +1715,58 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
         self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+    def test_predict_with_batch_eval_metrics(self):
+        trainer = get_regression_trainer(
+            a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
+        )
+        results = trainer.predict(trainer.eval_dataset)
+        preds = results.predictions
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+        gt = 1.5 * x + 2.5
+        self.assertTrue(np.allclose(preds, gt))
+        expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
+        self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
+
+        # With a number of elements not a round multiple of the batch size
+        trainer = get_regression_trainer(
+            a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
+        )
+        results = trainer.predict(trainer.eval_dataset)
+        preds = results.predictions
+        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+        expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
+        self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
+
+        # With more than one output of the model
+        trainer = get_regression_trainer(
+            a=1.5, b=2.5, double_output=True, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
+        )
+        preds = trainer.predict(trainer.eval_dataset).predictions
+        x = trainer.eval_dataset.x
+        self.assertEqual(len(preds), 2)
+        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
+        # With more than one output/label of the model
+        trainer = get_regression_trainer(
+            a=1.5,
+            b=2.5,
+            double_output=True,
+            label_names=["labels", "labels_2"],
+            compute_metrics=AlmostAccuracyBatched(),
+            batch_eval_metrics=True,
+        )
+        outputs = trainer.predict(trainer.eval_dataset)
+        preds = outputs.predictions
+        labels = outputs.label_ids
+        x = trainer.eval_dataset.x
+        self.assertEqual(len(preds), 2)
+        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
     def test_predict_with_jit(self):
         trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
         preds = trainer.predict(trainer.eval_dataset).predictions
...