Unverified Commit 3488ef5a authored by shabie's avatar shabie Committed by GitHub
Browse files

[trainer] add option to ignore keys for the train function too (#11719) (#12551)

parent 45dcfdec
......@@ -989,6 +989,7 @@ class Trainer:
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None,
ignore_keys_for_eval: Optional[List[str]] = None,
**kwargs,
):
"""
......@@ -1002,6 +1003,9 @@ class Trainer:
training will resume from the model/optimizer/scheduler states loaded here.
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (:obj:`List[str]`, `optional`)
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.
kwargs:
Additional keyword arguments used to hide deprecated arguments
"""
......@@ -1322,13 +1326,13 @@ class Trainer:
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available():
......@@ -1405,7 +1409,7 @@ class Trainer:
if len(load_result.unexpected_keys) != 0:
logger.warn(f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.")
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
......@@ -1423,7 +1427,7 @@ class Trainer:
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate()
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, epoch, metrics)
if self.control.should_save:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment