@@ -528,8 +543,10 @@ class TorchEvaluator(Evaluator):
----------
training_func
The training function is used to train the model; note that this is an entire optimization training loop.
It should have three required parameters [model, optimizers, criterion]
and three optional parameters [schedulers, max_steps, max_epochs].
``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``;
it is the ``optimizers`` passed to ``TorchEvaluator``.
``criterion`` and ``schedulers`` likewise correspond to the ``criterion`` and ``schedulers`` passed to ``TorchEvaluator``.
``max_steps`` and ``max_epochs`` are used to control the training duration.
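A minimal sketch of such a training function is given below, assuming a single (non-list) ``optimizers``
and a user-defined ``train_dataloader``; the default epoch count and loop structure are illustrative
assumptions, not part of the evaluator contract::

    def training_func(model, optimizers, criterion,
                      schedulers=None, max_steps=None, max_epochs=None):
        # Assumption: train_dataloader is defined elsewhere in user code.
        total_epochs = max_epochs if max_epochs else 1
        current_step = 0
        for _ in range(total_epochs):
            for inputs, labels in train_dataloader:
                optimizers.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                optimizers.step()
                current_step += 1
                if max_steps and current_step >= max_steps:
                    return
            if schedulers is not None:
                schedulers.step()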
...
...
@@ -574,7 +591,8 @@ class TorchEvaluator(Evaluator):
Optional. The traced ``_LRScheduler`` instance, i.e., an lr scheduler whose class is wrapped by ``nni.trace``.
E.g. ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``.
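A short sketch of building a traced optimizer/scheduler pair in this way; the ``model``, learning rate,
and ``gamma`` value are illustrative assumptions::

    import nni
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ExponentialLR

    traced_optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9)
    traced_scheduler = nni.trace(ExponentialLR)(traced_optimizer, gamma=0.1)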
dummy_input
Optional. The ``dummy_input`` is used to trace the graph;
it is the same as ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
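For example, a batch-shaped random tensor can serve as the dummy input; the MNIST-like shape below
is only an illustrative assumption::

    import torch

    dummy_input = torch.rand(8, 1, 28, 28)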
evaluating_func
Optional. A function that takes the model as input and returns the evaluation metric.
The return value can be a single float or a tuple (float, Any).
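A sketch of an accuracy-style evaluation function and of assembling the evaluator; ``test_dataloader``,
the chosen criterion, and the concrete metric are assumptions, and the keyword names follow the
parameters documented in this docstring::

    def evaluating_func(model):
        # Assumption: test_dataloader is defined elsewhere in user code.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                preds = model(inputs).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
        return correct / total

    # Assumption: any loss callable accepted by training_func can serve as the criterion.
    criterion = torch.nn.CrossEntropyLoss()
    evaluator = TorchEvaluator(training_func=training_func,
                               optimizers=traced_optimizer,
                               criterion=criterion,
                               lr_schedulers=traced_scheduler,
                               dummy_input=dummy_input,
                               evaluating_func=evaluating_func)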
...
...
@@ -634,14 +652,16 @@ class TorchEvaluator(Evaluator):
assert kernel_padding_mode in ['front', 'back'], f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
err_msg = f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."