Unverified Commit 2d6e2ad4 authored by François Lagunas's avatar François Lagunas Committed by GitHub
Browse files

Adding optional trial argument to model_init (#7759)



* Adding optional trial argument to model_init
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 7e73c128
...@@ -173,6 +173,9 @@ class Trainer: ...@@ -173,6 +173,9 @@ class Trainer:
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`): model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
A function that instantiates the model to be used. If provided, each call to A function that instantiates the model to be used. If provided, each call to
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function. :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
The function may take zero arguments, or a single one containing the optuna/Ray Tune trial object, to be able to choose
different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`): compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values. :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
...@@ -212,15 +215,16 @@ class Trainer: ...@@ -212,15 +215,16 @@ class Trainer:
assert ( assert (
model is not None or model_init is not None model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument." ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
if model is None and model_init is not None: if model is None and model_init is not None:
model = model_init() model = self.call_model_init()
self.model = model.to(args.device) if model is not None else None self.model = model.to(args.device) if model is not None else None
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.model_init = model_init
self.compute_metrics = compute_metrics self.compute_metrics = compute_metrics
self.optimizer, self.lr_scheduler = optimizers self.optimizer, self.lr_scheduler = optimizers
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
...@@ -532,6 +536,17 @@ class Trainer: ...@@ -532,6 +536,17 @@ class Trainer:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
def call_model_init(self, trial=None):
    """Instantiate a fresh model via ``self.model_init``, optionally passing the HPO trial.

    Supports both zero-argument factories and one-argument factories that accept
    the optuna / Ray Tune trial object, so hyperparameter search can choose
    different architectures per trial.

    Args:
        trial: The optuna trial or Ray Tune config for the current run, or
            ``None`` outside of a hyperparameter search.

    Returns:
        The model instance produced by ``self.model_init``.

    Raises:
        RuntimeError: If ``self.model_init`` takes more than one parameter.
    """
    # Inspect the factory's signature to decide how to call it.
    model_init_argcount = len(inspect.signature(self.model_init).parameters)
    if model_init_argcount == 0:
        model = self.model_init()
    elif model_init_argcount == 1:
        model = self.model_init(trial)
    else:
        # RuntimeError (not bare Exception) so callers can catch it specifically;
        # still caught by any existing `except Exception` handlers.
        raise RuntimeError("model_init should have 0 or 1 argument.")
    return model
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None): def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
""" """
Main training entry point. Main training entry point.
...@@ -550,7 +565,9 @@ class Trainer: ...@@ -550,7 +565,9 @@ class Trainer:
if self.model_init is not None: if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init. # Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed) set_seed(self.args.seed)
model = self.model_init()
model = self.call_model_init(trial)
self.model = model.to(self.args.device) self.model = model.to(self.args.device)
# Reinitializes optimizer and scheduler # Reinitializes optimizer and scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment