Unverified Commit a911b856 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub

Resolve conflicts for #4760 (#4762)

parent 14d2966b
@@ -170,26 +170,92 @@ class PBTClassArgsValidator(ClassArgsValidator):
    }).validate(kwargs)
class PBTTuner(Tuner):
"""
Population Based Training (PBT) comes from `Population Based Training of Neural Networks <https://arxiv.org/abs/1711.09846v1>`__.
It's a simple asynchronous optimization algorithm which effectively utilizes a fixed computational budget to jointly optimize
a population of models and their hyperparameters to maximize performance.
Importantly, PBT discovers a schedule of hyperparameter settings rather than following the generally sub-optimal strategy of
trying to find a single fixed set to use for the whole course of training.
.. image:: ../../img/pbt.jpg
PBTTuner initializes a population with several trials (i.e., ``population_size``).
There are four steps in the above figure; each trial runs only one step at a time. How long one step takes is controlled by trial code,
e.g., one epoch. When a trial starts, it loads a checkpoint specified by PBTTuner, continues to run one step,
then saves a checkpoint to a directory specified by PBTTuner and exits.
The trials in a population run steps synchronously, that is, after all the trials finish the ``i``-th step,
the ``(i+1)``-th step can be started. Exploitation and exploration of PBT are executed between two consecutive steps.
Two important steps to follow if you are trying to use PBTTuner:
1. **Provide a checkpoint directory**. Since some trials need to load other trials' checkpoints,
users should provide a directory (i.e., ``all_checkpoint_dir``) which is accessible by every trial.
This is easy in local mode: users can directly use the default directory or specify any directory on the local machine.
For other training services, users should follow
:doc:`the document of those training services </experiment/training_service/shared_storage>`
to provide a directory in a shared storage, such as NFS, Azure storage.
2. **Modify your trial code**. Before running a step, a trial needs to load a checkpoint;
the checkpoint directory is specified in the hyper-parameter configuration generated by PBTTuner,
i.e., ``params['load_checkpoint_dir']``. Similarly, the directory for saving a checkpoint is also included in the configuration,
i.e., ``params['save_checkpoint_dir']``. Here, ``all_checkpoint_dir`` is the base folder of ``load_checkpoint_dir``
and ``save_checkpoint_dir``, whose format is ``all_checkpoint_dir/<population-id>/<step>``.
.. code-block:: python

    params = nni.get_next_parameter()
    # the path of the checkpoint to load
    load_path = os.path.join(params['load_checkpoint_dir'], 'model.pth')
    # load checkpoint from `load_path`
    ...
    # run one step
    ...
    # the path for saving a checkpoint
    save_path = os.path.join(params['save_checkpoint_dir'], 'model.pth')
    # save checkpoint to `save_path`
    ...
The complete example code can be found :githublink:`here <examples/trials/mnist-pbt-tuner-pytorch>`.
Parameters
----------
optimize_mode : ``maximize`` or ``minimize``, default: ``maximize``
If ``maximize``, the tuner will target to maximize metrics. If ``minimize``, the tuner will target to minimize metrics.
all_checkpoint_dir : str
Directory for trials to load and save checkpoint.
If not specified, the directory would be ``~/nni/checkpoint/``.
Note that if the experiment is not local mode,
users should provide a path in a shared storage which can be accessed by all the trials.
population_size : int, default = 10
Number of trials in a population. Each step has this number of trials.
In our implementation, one step means running each trial for a specific number of training epochs set by users.
factor : float, default = 0.2
Factor for perturbation of hyperparameters.
resample_probability : float, default = 0.25
Probability for resampling.
fraction : float, default = 0.2
Fraction for selecting bottom and top trials.
Examples
--------
Below is an example of PBTTuner configuration in experiment config file.
.. code-block:: yaml

    tuner:
      name: PBTTuner
      classArgs:
        optimize_mode: maximize
        all_checkpoint_dir: /the/path/to/store/checkpoints
        population_size: 10
Notes
-----
Assessor is not allowed if PBTTuner is used.
"""
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
             resample_probability=0.25, fraction=0.2):
"""
Initialization
Parameters
----------
optimize_mode : str
maximize or minimize
all_checkpoint_dir : str
directory to store training model checkpoint
population_size : int
number of trials for each epoch
factor : float
factor for perturbation
resample_probability : float
probability for resampling
fraction : float
fraction for selecting bottom and top trials
"""
self.optimize_mode = OptimizeMode(optimize_mode)
if all_checkpoint_dir is None:
    all_checkpoint_dir = os.getenv('NNI_CHECKPOINT_DIRECTORY')
......
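The exploit-and-explore behavior that the ``factor``, ``resample_probability``, and ``fraction`` parameters control can be sketched as follows. This is a simplified illustration under assumed details (the function name, the truncation selection, and the uniform resampling range are not PBTTuner's actual implementation):

```python
import random

def exploit_and_explore(population, factor=0.2, resample_probability=0.25, fraction=0.2):
    """One PBT exploit/explore pass (simplified sketch).

    population: list of dicts like {'score': float, 'hparams': {...}}.
    The bottom `fraction` of trials copy hyperparameters from the top `fraction`,
    then each copied value is perturbed by (1 +/- factor) or resampled.
    """
    ranked = sorted(population, key=lambda t: t['score'], reverse=True)
    n = max(1, int(len(ranked) * fraction))
    tops, bottoms = ranked[:n], ranked[-n:]
    for trial in bottoms:
        source = random.choice(tops)
        trial['hparams'] = dict(source['hparams'])  # exploit: copy a top trial
        for name, value in trial['hparams'].items():
            if random.random() < resample_probability:
                trial['hparams'][name] = random.uniform(0.0, 1.0)  # explore: resample (assumed range)
            else:
                perturb = 1 + factor if random.random() < 0.5 else 1 - factor
                trial['hparams'][name] = value * perturb  # explore: perturb
    return population
```

The key point the sketch shows is that exploitation happens between steps, not within a trial: a bottom trial restarts its next step from a top trial's checkpoint and a perturbed copy of its hyperparameters.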
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .ppo_tuner import PPOTuner, PPOClassArgsValidator
@@ -306,40 +306,37 @@ class PPOClassArgsValidator(ClassArgsValidator):
class PPOTuner(Tuner):
"""
PPOTuner inherits the main logic of the implementation of
`ppo2 from openai <https://github.com/openai/baselines/tree/master/baselines/ppo2>`__ and is adapted for the NAS scenario.
It uses an ``lstm`` for its policy network and value network; policy and value share the same network.
Parameters
----------
optimize_mode : str
'maximize' or 'minimize'
trials_per_update : int
Number of trials to have for each model update
epochs_per_update : int
Number of epochs to run for each model update
minibatch_size : int
Minibatch size (number of trials) for the update
ent_coef : float
Policy entropy coefficient in the optimization objective
lr : float
Learning rate of the model (lstm network), constant
vf_coef : float
Value function loss coefficient in the optimization objective
max_grad_norm : float
Gradient norm clipping coefficient
gamma : float
Discounting factor
lam : float
Advantage estimation discounting factor (lambda in the paper)
cliprange : float
Cliprange in the PPO algorithm, constant
"""
def __init__(self, optimize_mode, trials_per_update=20, epochs_per_update=4, minibatch_size=4,
             ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, cliprange=0.2):
"""
Initialization, PPO model is not initialized here as search space is not received yet.
Parameters
----------
optimize_mode : str
maximize or minimize
trials_per_update : int
Number of trials to have for each model update
epochs_per_update : int
Number of epochs to run for each model update
minibatch_size : int
Minibatch size (number of trials) for the update
ent_coef : float
Policy entropy coefficient in the optimization objective
lr : float
Learning rate of the model (lstm network), constant
vf_coef : float
Value function loss coefficient in the optimization objective
max_grad_norm : float
Gradient norm clipping coefficient
gamma : float
Discounting factor
lam : float
Advantage estimation discounting factor (lambda in the paper)
cliprange : float
Cliprange in the PPO algorithm, constant
"""
self.optimize_mode = OptimizeMode(optimize_mode)
self.model_config = ModelConfig()
self.model = None
......
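Unlike the other tuners touched by this commit, the PPO docstring lists parameters but gives no configuration example. A hypothetical one in the same style as the TPE and SMAC examples (the numeric values simply repeat the signature defaults and are illustrative, not recommendations):

```python
config.tuner.name = 'PPOTuner'
config.tuner.class_args = {
    'optimize_mode': 'maximize',  # required: PPOTuner's signature has no default
    'trials_per_update': 20,      # trials collected before each model update
    'epochs_per_update': 4,       # optimization epochs per update
    'minibatch_size': 4,          # number of trials per mini-batch
}
```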
@@ -2,12 +2,14 @@
# Licensed under the MIT license.
"""
Naive random tuner.
You can specify an integer seed to determine the random result.
"""
from __future__ import annotations
__all__ = ['RandomTuner']
import logging
@@ -15,24 +17,46 @@ import numpy as np
import schema
from nni import ClassArgsValidator
from nni.common.hpo_utils import Deduplicator, format_search_space, deformat_parameters
from nni.tuner import Tuner
_logger = logging.getLogger('nni.tuner.random')
class RandomTuner(Tuner):
"""
A naive tuner that generates fully random hyperparameters.
Examples
--------
.. code-block::

    config.tuner.name = 'Random'
    config.tuner.class_args = {
        'seed': 100
    }
Parameters
----------
seed
The random seed.
"""
def __init__(self, seed: int | None = None):
self.space = None
if seed is None:  # explicitly generate a seed to make the experiment reproducible
    seed = np.random.default_rng().integers(2 ** 31)
self.rng = np.random.default_rng(seed)
self.dedup = None
_logger.info(f'Using random seed {seed}')
def update_search_space(self, space):
self.space = format_search_space(space)
self.dedup = Deduplicator(self.space)
def generate_parameters(self, *args, **kwargs):
params = suggest(self.rng, self.space)
params = self.dedup(params)
return deformat_parameters(params, self.space)
def receive_trial_result(self, *args, **kwargs):
......
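The seeding logic in the tuner above makes runs reproducible: when no seed is given, a concrete one is generated explicitly and logged, so it can be passed back in later to replay the experiment. A minimal sketch of why this works:

```python
import numpy as np

# Generating a concrete seed up front (instead of leaving the RNG unseeded)
# means the experiment can log it and be replayed exactly later.
seed = int(np.random.default_rng().integers(2 ** 31))

rng_a = np.random.default_rng(seed)
rng_b = np.random.default_rng(seed)  # a "resumed" run using the logged seed

draws_a = rng_a.integers(0, 100, size=5).tolist()
draws_b = rng_b.integers(0, 100, size=5).tolist()
assert draws_a == draws_b  # identical sequences from the same seed
```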
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import random
......
@@ -38,20 +38,46 @@ class SMACClassArgsValidator(ClassArgsValidator):
class SMACTuner(Tuner):
"""
`SMAC <https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf>`__ is based on Sequential Model-Based Optimization (SMBO).
It adapts the most prominent previously used model class (Gaussian stochastic process models)
and introduces the model class of random forests to SMBO in order to handle categorical parameters.
The SMAC supported by NNI is a wrapper of `the SMAC3 GitHub repo <https://github.com/automl/SMAC3>`__,
following the NNI tuner interface :class:`nni.tuner.Tuner`. For algorithm details of SMAC, please refer to the paper
:footcite:t:`hutter2011sequential`.
Note that SMAC on nni only supports a subset of the types in
:doc:`search space </hpo/search_space>`:
``choice``, ``randint``, ``uniform``, ``loguniform``, and ``quniform``.
Note that SMAC needs additional installation using the following command:
.. code-block:: bash

    pip install nni[SMAC]
``swig`` is required for SMAC. On Ubuntu, ``swig`` can be installed with ``apt``.
Examples
--------
.. code-block::

    config.tuner.name = 'SMAC'
    config.tuner.class_args = {
        'optimize_mode': 'maximize'
    }
Parameters
----------
optimize_mode : str
Optimize mode, 'maximize' or 'minimize', by default 'maximize'
config_dedup : bool
If True, the tuner will not generate a configuration that has been already generated.
If False, a configuration may be generated twice, but this is rare for a relatively large search space.
"""
def __init__(self, optimize_mode="maximize", config_dedup=False):
"""
Parameters
----------
optimize_mode : str
Optimize mode, 'maximize' or 'minimize', by default 'maximize'
config_dedup : bool
If True, the tuner will not generate a configuration that has been already generated.
If False, a configuration may be generated twice, but it is rare for relatively large search space.
"""
self.logger = logger
self.optimize_mode = OptimizeMode(optimize_mode)
self.total_data = {}
......
@@ -2,26 +2,30 @@
# Licensed under the MIT license.
"""
Tree-structured Parzen Estimator (TPE) tuner.
Paper: https://proceedings.neurips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf
Official code: https://github.com/hyperopt/hyperopt/blob/master/hyperopt/tpe.py
This is a slightly modified re-implementation of the algorithm.
"""
from __future__ import annotations
__all__ = ['TpeTuner', 'TpeArguments']
from collections import defaultdict
import logging
import math
from typing import Any, NamedTuple
import numpy as np
from scipy.special import erf  # pylint: disable=no-name-in-module
from nni.common.hpo_utils import Deduplicator, OptimizeMode, format_search_space, deformat_parameters, format_parameters
from nni.tuner import Tuner
from nni.typehint import Literal
from nni.utils import extract_scalar_reward
from . import random_tuner
@@ -31,12 +35,13 @@ _logger = logging.getLogger('nni.tuner.tpe')
class TpeArguments(NamedTuple):
"""
Hyperparameters of TPE algorithm itself.
To avoid confusing with trials' hyperparameters to be tuned, these are called "arguments" here.
Parameters
----------
constant_liar_type
TPE algorithm itself does not support parallel tuning.
This parameter specifies how to optimize for trial_concurrency > 1.
@@ -44,20 +49,21 @@ class TpeArguments(NamedTuple):
How each liar works is explained in the paper's section 6.1.
In general "best" suits a small trial number and "worst" suits a large trial number.
(:doc:`experiment result </sharings/parallelizing_tpe_search>`)
n_startup_jobs
The first N hyperparameters are generated fully randomly for warming up.
If the search space is large, you can increase this value.
Or if max_trial_number is small, you may want to decrease it.
n_ei_candidates
For each iteration TPE samples EI for N sets of parameters and chooses the best one. (loosely speaking)
linear_forgetting
TPE will lower the weights of old trials.
This controls how many iterations it takes for a trial to start to decay.
prior_weight
TPE treats the user-provided search space as prior.
When generating new trials, it also incorporates the prior in trial history by transforming the search space to
one trial configuration (i.e., each parameter of this configuration chooses the mean of its candidate range).
@@ -66,11 +72,11 @@ class TpeArguments(NamedTuple):
With prior weight 1.0, the search space is treated as one good trial.
For example, "normal(0, 1)" effectively equals a trial with x = 0 which has yielded a good result.
gamma
Controls how many trials are considered "good".
The number is calculated as "min(gamma * sqrt(N), linear_forgetting)".
"""
constant_liar_type: Literal['best', 'worst', 'mean'] | None = 'best'
n_startup_jobs: int = 20
n_ei_candidates: int = 24
linear_forgetting: int = 25
@@ -79,24 +85,75 @@ class TpeTuner(Tuner):
class TpeTuner(Tuner):
"""
Tree-structured Parzen Estimator (TPE) tuner.
TPE is a lightweight tuner that has no extra dependency and supports all search space types,
designed to be the default tuner.
It has the drawback that TPE cannot discover relationships between different hyperparameters.
**Implementation**
TPE is an SMBO algorithm.
It models P(x|y) and P(y) where x represents hyperparameters and y the evaluation result.
P(x|y) is modeled by transforming the generative process of hyperparameters,
replacing the distributions of the configuration prior with non-parametric densities.
Paper: `Algorithms for Hyper-Parameter Optimization
<https://proceedings.neurips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf>`__
Examples
--------
.. code-block::

    ## minimal config ##
    config.tuner.name = 'TPE'
    config.tuner.class_args = {
        'optimize_mode': 'maximize'
    }

.. code-block::

    ## advanced config ##
    config.tuner.name = 'TPE'
    config.tuner.class_args = {
        'optimize_mode': 'maximize',
        'seed': 12345,
        'tpe_args': {
            'constant_liar_type': 'mean',
            'n_startup_jobs': 10,
            'n_ei_candidates': 20,
            'linear_forgetting': 100,
            'prior_weight': 0,
            'gamma': 0.5
        }
    }
Parameters
----------
optimize_mode: Literal['minimize', 'maximize']
Whether to minimize or maximize the trial result.
seed
The random seed.
tpe_args
Advanced users can use this to customize the TPE tuner.
See :class:`TpeArguments` for details.
"""
def __init__(self,
optimize_mode: Literal['minimize', 'maximize'] = 'minimize',
seed: int | None = None,
tpe_args: dict[str, Any] | None = None):
self.optimize_mode = OptimizeMode(optimize_mode)
self.args = TpeArguments(**(tpe_args or {}))
self.space = None
# concurrent generate_parameters() calls are likely to yield similar result, because they use same history
# the liar solves this problem by adding fake results to history
self.liar = create_liar(self.args.constant_liar_type)
self.dedup = None
if seed is None:  # explicitly generate a seed to make the experiment reproducible
    seed = np.random.default_rng().integers(2 ** 31)
@@ -109,6 +166,7 @@ class TpeTuner(Tuner):
def update_search_space(self, space):
self.space = format_search_space(space)
self.dedup = Deduplicator(self.space)
def generate_parameters(self, parameter_id, **kwargs):
if self.liar and self._running_params:
@@ -122,6 +180,7 @@ class TpeTuner(Tuner):
history = self._history
params = suggest(self.args, self.rng, self.space, history)
params = self.dedup(params)
self._params[parameter_id] = params
self._running_params[parameter_id] = params
@@ -183,7 +242,7 @@ def suggest_parameter(args, rng, spec, parameter_history):
## Utilities part ##
class Record(NamedTuple):
param: int | float
loss: float
class BestLiar:  # assume running parameters have best result, it accelerates "converging"
@@ -305,7 +364,7 @@ def adaptive_parzen_normal(args, history_mus, prior_mu, prior_sigma):
this function is used for everything other than "choice" and "randint".
Parameters
----------
args: TpeArguments
Algorithm arguments.
history_mus: 1-d array of float
@@ -317,7 +376,7 @@ def adaptive_parzen_normal(args, history_mus, prior_mu, prior_sigma):
σ value of normal search space.
Returns
-------
Tuple of three 1-d float arrays: (weight, µ, σ).
The tuple represents N+1 "vicinity of observations" and each one's weight,
......
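The ``gamma`` rule quoted in the TpeArguments docstring, ``min(gamma * sqrt(N), linear_forgetting)``, can be written out directly. A small sketch (the ceiling rounding is an assumption of this illustration; the tuner's exact rounding may differ):

```python
import math

def n_good_trials(n_history, gamma=0.25, linear_forgetting=25):
    """Number of trials treated as "good": min(gamma * sqrt(N), linear_forgetting)."""
    return min(int(math.ceil(gamma * math.sqrt(n_history))), linear_forgetting)

# The "good" set grows with the square root of history size, then saturates:
assert n_good_trials(16) == 1        # ceil(0.25 * 4) = 1
assert n_good_trials(100) == 3       # ceil(0.25 * 10) = 3
assert n_good_trials(100000) == 25   # capped by linear_forgetting
```

The cap at ``linear_forgetting`` ties the two arguments together: trials older than the forgetting window never count toward the "good" density model.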
@@ -3,6 +3,7 @@
from __future__ import absolute_import, division, print_function
import ast
import os
import timeit
import torch
@@ -10,17 +11,20 @@ import torch
import numpy as np
import torch.nn as nn
from nni.compression.pytorch.utils import count_flops_params
LUT_FILE = "lut.npy"
LUT_JSON_FILE = "lut.txt"
LUT_PATH = "lut"
DATA_TYPE = "float"
class NASConfig:
def __init__(
self,
perf_metric="flops",
lut_load=False,
lut_load_format="json",
model_dir=None,
nas_lr=0.01,
nas_weight_decay=5e-4,
@@ -41,6 +45,13 @@ class NASConfig:
], "perf_metric should be ['flops', 'latency']"
# whether to load or create the lut file
self.lut_load = lut_load
assert lut_load_format in [
"json",
"numpy",
], "lut_load_format should be ['json', 'numpy']"
self.lut_load_format = lut_load_format
# necessary dirs
self.lut_en = model_dir is not None
if self.lut_en:
@@ -252,8 +263,14 @@ class LookUpTable:
if config.lut_en:
self.lut_perf = None
self.lut_file = os.path.join(config.lut_path, LUT_FILE)
self.lut_json_file = LUT_JSON_FILE
if config.lut_load:
if config.lut_load_format == "numpy":
# Load data from numpy file
self._load_from_file()
else:
# Load data from json file
self._load_from_json_file()
else:
self._create_perfs()
@@ -349,3 +366,68 @@ class LookUpTable:
def _load_from_file(self):
"""Load numpy file."""
self.lut_perf = np.load(self.lut_file, allow_pickle=True)
def _load_from_json_file(self):
"""Load json file.
lut_json_file ('lut.txt') format:
{'op_name': operator_name,
'op_data_shape': (input_w, input_h, C_in, C_out, stride),
'op_dtype': data_type,
'op_latency': latency}
{...}
{...}
"""
with open(self.lut_json_file, "r") as latency_file:
    ops_latency = latency_file.readlines()
# ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}
ops_lut = {}
for op_latency in ops_latency:
assert isinstance(op_latency, str) or isinstance(op_latency, dict)
if isinstance(op_latency, str):
record = ast.literal_eval(op_latency)
elif isinstance(op_latency, dict):
record = op_latency
op_name = record["op_name"]
# op_data_shape: (input_w, input_h, C_in, C_out, stride)
op_data_shape = record["op_data_shape"]
op_dtype = record["op_dtype"]
op_latency = record["op_latency"]
if op_name not in ops_lut:
ops_lut[op_name] = {}
if op_data_shape not in ops_lut[op_name]:
ops_lut[op_name][op_data_shape] = {}
ops_lut[op_name][op_data_shape][op_dtype] = op_latency
self.lut_perf = [{} for i in range(self.cnt_layers)]
layer_id = 0
for stage_name in self.lut_ops:
stage_ops = self.lut_ops[stage_name]
ops_num = self.layer_num[stage_name]
for _ in range(ops_num):
for op_name in stage_ops:
layer_config = self.layer_configs[layer_id]
layer_in_shape = self.layer_in_shapes[layer_id]
input_w = layer_in_shape[1]
input_h = layer_in_shape[2]
c_in = layer_config[0]
c_out = layer_config[1]
stride = layer_config[2]
op_data_shape = (input_w, input_h, c_in, c_out, stride)
if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
self.lut_perf[layer_id][op_name] = \
ops_lut[op_name][op_data_shape][DATA_TYPE]
layer_id += 1
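Each line of the LUT text file holds one Python-literal record, which is why ``ast.literal_eval`` is used above: it parses literals safely without executing code. A minimal standalone sketch of parsing one such line and indexing it the way ``_load_from_json_file`` builds ``ops_lut`` (the field values are made up):

```python
import ast

# One line of a hypothetical lut.txt, in the format described in the docstring above.
line = ("{'op_name': 'conv3x3', 'op_data_shape': (32, 32, 16, 32, 1), "
        "'op_dtype': 'float', 'op_latency': 0.42}")

record = ast.literal_eval(line)  # parses literals only, never runs code
assert record['op_name'] == 'conv3x3'
assert record['op_data_shape'] == (32, 32, 16, 32, 1)

# Nest it as ops_lut does: {op_name: {op_data_shape: {op_dtype: latency}}}
ops_lut = {record['op_name']: {record['op_data_shape']: {record['op_dtype']: record['op_latency']}}}
assert ops_lut['conv3x3'][(32, 32, 16, 32, 1)]['float'] == 0.42
```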
@@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not.
See :class:`Assessor`'s specification and ``docs/en_US/assessors.rst`` for details.
"""
from __future__ import annotations
from enum import Enum
import logging
from .recoverable import Recoverable
from .typehint import TrialMetric
__all__ = ['AssessResult', 'Assessor']
@@ -54,7 +57,7 @@ class Assessor(Recoverable):
:class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor`
"""
def assess_trial(self, trial_job_id: str, trial_history: list[TrialMetric]) -> AssessResult:
"""
Abstract method for determining whether a trial should be killed. Must override.
@@ -91,7 +94,7 @@ class Assessor(Recoverable):
"""
raise NotImplementedError('Assessor: assess_trial not implemented')
def trial_end(self, trial_job_id: str, success: bool) -> None:
"""
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
@@ -103,22 +106,22 @@ class Assessor(Recoverable):
True if the trial successfully completed; False if failed or terminated.
"""
def load_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoint_path)
def save_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoint_path)
def _on_exit(self) -> None:
pass
def _on_error(self) -> None:
pass
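The typed ``assess_trial`` interface above is what a custom assessor implements. A sketch with minimal stand-ins for the NNI base classes (the real ``AssessResult`` lives in ``nni.assessor``; the running-mean stopping rule here is purely illustrative, not an NNI algorithm):

```python
from enum import Enum

class AssessResult(Enum):  # stand-in for nni's AssessResult (Good = True, Bad = False)
    Good = True
    Bad = False

class SimpleAssessor:
    """Illustrative assessor: stop a trial whose latest metric falls below
    the mean of its earlier history."""

    def assess_trial(self, trial_job_id: str, trial_history: list) -> AssessResult:
        if len(trial_history) < 3:  # too little evidence; keep the trial running
            return AssessResult.Good
        mean_so_far = sum(trial_history[:-1]) / (len(trial_history) - 1)
        return AssessResult.Good if trial_history[-1] >= mean_so_far else AssessResult.Bad

    def trial_end(self, trial_job_id: str, success: bool) -> None:
        pass  # nothing to clean up in this sketch
```

Returning ``AssessResult.Bad`` is how an assessor asks NNI to early-stop the trial; ``trial_end`` is then invoked with ``success=False``.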
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .serializer import trace, dump, load, is_traceable
@@ -47,7 +47,7 @@ class TorchGraph:
Parameters
----------
model : pytorch model
The model the user wants to speed up
dummy_input : pytorch tensor
The dummy input for ``jit.trace``; users should put it on the right device before passing it in
traced_model : torch._C.torch.jit.TopLevelTracedModule
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .dedup import Deduplicator
from .validation import validate_search_space
from .formatting import *
from .optimize_mode import OptimizeMode
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Deduplicate repeated parameters.
No guarantee for forward-compatibility.
"""
from __future__ import annotations
import logging
import typing
import nni
from .formatting import FormattedParameters, FormattedSearchSpace, ParameterSpec, deformat_parameters
# TODO:
# Move main logic of basic tuners (random and grid search) into SDK,
# so we can get rid of private methods and circular dependency.
if typing.TYPE_CHECKING:
from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner
_logger = logging.getLogger(__name__)
class Deduplicator:
    """
    A helper for tuners to deduplicate generated parameters.

    When the tuner generates an already existing parameter,
    calling this will return a new parameter generated with grid search.
    Otherwise it returns the original parameter object.

    If all parameters have been generated, raise ``NoMoreTrialError``.

    All search space types, including nested choice, are supported.

    Resuming and updating search space are not supported for now.
    This will not raise an error, but may return duplicate parameters.

    See the random tuner's source code for example usage.
    """

    def __init__(self, formatted_search_space: FormattedSearchSpace):
        self._space: FormattedSearchSpace = formatted_search_space
        self._never_dup: bool = any(_spec_never_dup(spec) for spec in self._space.values())
        self._history: set[str] = set()
        self._grid_search: GridSearchTuner | None = None

    def __call__(self, formatted_parameters: FormattedParameters) -> FormattedParameters:
        if self._never_dup or self._not_dup(formatted_parameters):
            return formatted_parameters

        if self._grid_search is None:
            _logger.info('Tuning algorithm generated duplicate parameter: %s', formatted_parameters)
            _logger.info('Using grid search for deduplication.')
            self._init_grid_search()

        while True:
            new = self._grid_search._suggest()  # type: ignore
            if new is None:
                raise nni.NoMoreTrialError()
            if self._not_dup(new):
                return new

    def _init_grid_search(self) -> None:
        from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner
        self._grid_search = GridSearchTuner()
        self._grid_search.history = self._history
        self._grid_search.space = self._space
        self._grid_search._init_grid()

    def _not_dup(self, formatted_parameters: FormattedParameters) -> bool:
        params = deformat_parameters(formatted_parameters, self._space)
        params_str = typing.cast(str, nni.dump(params, sort_keys=True))
        if params_str in self._history:
            return False
        else:
            self._history.add(params_str)
            return True

def _spec_never_dup(spec: ParameterSpec) -> bool:
    if spec.is_nested():
        return False  # "not chosen" duplicates with "not chosen"
    if spec.categorical or spec.q is not None:
        return False
    if spec.normal_distributed:
        return spec.sigma > 0
    else:
        return spec.low < spec.high
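The core of the deduplication above is a canonical-string history set: parameters are serialized with sorted keys so that logically equal dicts collide. A minimal standalone sketch of that idea (`SimpleDeduplicator` is a hypothetical name; NNI's real `Deduplicator` additionally falls back to grid search when a duplicate is seen):

```python
import json

class SimpleDeduplicator:
    """History-based parameter deduplication (sketch, not NNI API)."""

    def __init__(self):
        self._history = set()

    def is_new(self, params):
        # Canonical string form so {'a': 1, 'b': 2} == {'b': 2, 'a': 1}.
        key = json.dumps(params, sort_keys=True)
        if key in self._history:
            return False
        self._history.add(key)
        return True

dedup = SimpleDeduplicator()
assert dedup.is_new({'lr': 0.1, 'bs': 32})
assert not dedup.is_new({'bs': 32, 'lr': 0.1})  # same params, different key order
```

The same canonicalization (`sort_keys=True`) is what makes `nni.dump(params, sort_keys=True)` a safe set membership key in `_not_dup`.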
@@ -2,6 +2,8 @@
# Licensed under the MIT license.

"""
Helper class and functions for tuners to deal with search space.

This script provides a more program-friendly representation of HPO search space.
The format is considered an internal helper and is not visible to end users.
@@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space.
The random tuner is an intuitive example for this utility.
You should check its code before reading the docstrings in this file.

.. attention::

    This module does not guarantee forward-compatibility.
    If you want to use it outside the official NNI repo, it is recommended to copy the script.
"""
from __future__ import annotations

__all__ = [
    'ParameterSpec',
    'deformat_parameters',
@@ -20,10 +30,16 @@ __all__ = [

import math
from types import SimpleNamespace
from typing import Any, Dict, NamedTuple, Tuple, cast

import numpy as np

from nni.typehint import Parameters, SearchSpace

ParameterKey = Tuple['str | int', ...]
FormattedParameters = Dict[ParameterKey, 'float | int']
FormattedSearchSpace = Dict[ParameterKey, 'ParameterSpec']
class ParameterSpec(NamedTuple):
    """
    Specification (aka space / range / domain) of one single parameter.
@@ -33,52 +49,59 @@ class ParameterSpec(NamedTuple):
    name: str                       # The object key in JSON
    type: str                       # "_type" in JSON
    values: list[Any]               # "_value" in JSON
    key: ParameterKey               # The "path" of this parameter
    categorical: bool               # Whether this parameter is categorical (unordered) or numerical (ordered)

    size: int = cast(int, None)     # If it's categorical, how many candidates it has
    chosen_size: int | None = 1     # If it's categorical, how many candidates should be chosen.
                                    # By default, 1. If None, an arbitrary number of candidates can be chosen.

    # uniform distributed
    low: float = cast(float, None)  # Lower bound of uniform parameter
    high: float = cast(float, None) # Upper bound of uniform parameter

    normal_distributed: bool = cast(bool, None)
                                    # Whether this parameter is uniform or normal distributed
    mu: float = cast(float, None)   # µ of normal parameter
    sigma: float = cast(float, None)  # σ of normal parameter

    q: float | None = None          # If not `None`, the parameter value should be an integer multiple of this
    clip: tuple[float, float] | None = None
                                    # For q(log)uniform, this equals "values[:2]"; for others this is None

    log_distributed: bool = cast(bool, None)
                                    # Whether this parameter is log distributed.
                                    # When true, low/high/mu/sigma describe the log of the parameter value (like np.lognormal)
    def is_activated_in(self, partial_parameters: FormattedParameters) -> bool:
        """
        For nested search space, check whether this parameter should be skipped for the current set of parameters.
        This function must be used in a pattern similar to the random tuner. Otherwise it will misbehave.
        """
        if self.is_nested():
            return partial_parameters[self.key[:-2]] == self.key[-2]
        return True

    def is_nested(self) -> bool:
        """
        Check whether this parameter is inside a nested choice.
        """
        return len(self.key) >= 2 and isinstance(self.key[-2], int)
def format_search_space(search_space: SearchSpace) -> FormattedSearchSpace:
    """
    Convert user-provided search space into a dict of ParameterSpec.
    The dict key is the dict value's `ParameterSpec.key`.
    """
    formatted = _format_search_space(tuple(), search_space)
    # In CPython 3.6, dicts preserve order as an implementation detail.
    # In Python 3.7+, dicts preserve order by language spec.
    # Remove these comments when we drop 3.6 support.
    return {spec.key: spec for spec in formatted}
def deformat_parameters(
        formatted_parameters: FormattedParameters,
        formatted_search_space: FormattedSearchSpace) -> Parameters:
    """
    Convert internal format parameters to users' expected format.
@@ -90,10 +113,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
    3. For "q*", convert x to `round(x / q) * q`, then clip into range.
    4. For nested choices, convert flattened key-value pairs into a nested structure.
    """
    ret: Parameters = {}
    for key, x in formatted_parameters.items():
        spec = formatted_search_space[key]
        if spec.categorical:
            x = cast(int, x)
            if spec.type == 'randint':
                lower = min(math.ceil(float(x)) for x in spec.values)
                _assign(ret, key, int(lower + x))
@@ -114,7 +138,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
            _assign(ret, key, x)
    return ret
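Step 3 of the docstring ("convert x to `round(x / q) * q`, then clip into range") can be sketched in isolation (`apply_q` is a hypothetical helper for illustration, not part of the module):

```python
def apply_q(x, q, clip):
    """Round x to the nearest integer multiple of q, then clip into [low, high]."""
    v = round(x / q) * q
    low, high = clip
    return min(max(v, low), high)

assert apply_q(3.7, 0.5, (0.0, 10.0)) == 3.5    # 7.4 rounds to 7 multiples of 0.5
assert apply_q(12.3, 1.0, (0.0, 10.0)) == 10.0  # clipped to the upper bound
```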
def format_parameters(parameters: Parameters, formatted_search_space: FormattedSearchSpace) -> FormattedParameters:
    """
    Convert end users' parameter format back to internal format, mainly for resuming experiments.
@@ -125,7 +149,7 @@ def format_parameters(parameters, formatted_search_space):
    for key, spec in formatted_search_space.items():
        if not spec.is_activated_in(ret):
            continue
        value: Any = parameters
        for name in key:
            if isinstance(name, str):
                value = value[name]
@@ -144,8 +168,8 @@ def format_parameters(parameters, formatted_search_space):
        ret[key] = value
    return ret
def _format_search_space(parent_key: ParameterKey, space: SearchSpace) -> list[ParameterSpec]:
    formatted: list[ParameterSpec] = []
    for name, spec in space.items():
        if name == '_name':
            continue
@@ -157,7 +181,7 @@ def _format_search_space(parent_key, space):
        formatted += _format_search_space(key, sub_space)
    return formatted

def _format_parameter(key: ParameterKey, type_: str, values: list[Any]):
    spec = SimpleNamespace(
        name = key[-1],
        type = type_,
@@ -199,7 +223,7 @@ def _format_parameter(key, type_, values):
    return ParameterSpec(**spec.__dict__)

def _is_nested_choices(values: list[Any]) -> bool:
    assert values  # choices should not be empty
    for value in values:
        if not isinstance(value, dict):
@@ -208,9 +232,9 @@ def _is_nested_choices(values):
            return False
    return True

def _assign(params: Parameters, key: ParameterKey, x: Any) -> None:
    if len(key) == 1:
        params[cast(str, key[0])] = x
    elif isinstance(key[0], int):
        _assign(params, key[1:], x)
    else:
......
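`_assign` above folds a flattened key path back into a nested dict, skipping integer path elements that mark nested-choice levels. A self-contained sketch of that idea (the recursive-descent branch for string keys is an assumption here, since the tail of the function is collapsed in the diff):

```python
def assign(params, key, x):
    """Assign x into a nested dict following a key path (illustrative sketch).

    Integer path elements mark nested-choice levels and are skipped,
    so ('optimizer', 0, 'lr') folds into {'optimizer': {'lr': ...}}.
    """
    if len(key) == 1:
        params[key[0]] = x
    elif isinstance(key[0], int):
        assign(params, key[1:], x)       # skip the choice index
    else:
        params.setdefault(key[0], {})    # descend, creating sub-dicts as needed
        assign(params[key[0]], key[1:], x)

ret = {}
assign(ret, ('optimizer', 0, 'lr'), 0.01)
assert ret == {'optimizer': {'lr': 0.01}}
```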
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from enum import Enum

class OptimizeMode(Enum):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
from typing import Any

common_search_space_types = [
    'choice',
@@ -19,7 +21,7 @@ common_search_space_types = [
def validate_search_space(
        search_space: Any,
        support_types: list[str] | None = None,
        raise_exception: bool = False  # for now, in case of false positives
) -> bool:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import base64
import collections.abc
@@ -5,6 +8,8 @@ import copy
import functools
import inspect
import numbers
import os
import sys
import types
import warnings
from io import IOBase
@@ -13,7 +18,7 @@ from typing import Any, Dict, List, Optional, TypeVar, Union

import cloudpickle  # use cloudpickle as backend for unserializable types and instances
import json_tricks  # use json_tricks as serializer backend

__all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable', 'is_wrapped_with_trace']

T = TypeVar('T')
@@ -23,46 +28,49 @@ class PayloadTooLarge(Exception):
    pass
class Traceable:
    """
    A traceable object has copy and dict. Copy and mutate are used to copy the object for further mutations.
    Dict returns a TraceDictType to enable serialization.
    """

    def trace_copy(self) -> 'Traceable':
        """
        Perform a shallow copy.
        NOTE: NONE of the attributes will be preserved.
        This is the one that should be used when you want to "mutate" a serializable object.
        """
        raise NotImplementedError()

    @property
    def trace_symbol(self) -> Any:
        """
        Symbol object. Could be a class or a function.
        ``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` are a pair to
        convert the symbol into a string and convert the string back to a symbol.
        """
        raise NotImplementedError()

    @property
    def trace_args(self) -> List[Any]:
        """
        List of positional arguments passed to the symbol. Usually empty if ``kw_only`` is true,
        in which case all the positional arguments are converted into keyword arguments.
        """
        raise NotImplementedError()

    @property
    def trace_kwargs(self) -> Dict[str, Any]:
        """
        Dict of keyword arguments.
        """
        raise NotImplementedError()

    def get(self) -> Any:
        """
        Get the original object. Usually used together with ``trace_copy``.
        """
        raise NotImplementedError()

class Translatable(abc.ABC):
@@ -84,13 +92,27 @@ class Translatable(abc.ABC):
def is_traceable(obj: Any) -> bool:
    """
    Check whether an object is a traceable instance or type.

    Note that an object being traceable only means that it implements the "Traceable" interface,
    and the properties have been implemented. It doesn't necessarily mean that its type is wrapped with trace,
    because the properties could be added **after** the instance has been created.
    """
    return hasattr(obj, 'trace_copy') and \
        hasattr(obj, 'trace_symbol') and \
        hasattr(obj, 'trace_args') and \
        hasattr(obj, 'trace_kwargs')

def is_wrapped_with_trace(cls_or_func: Any) -> bool:
    """
    Check whether a function or class is already wrapped with ``@nni.trace``.
    If a class or function is already wrapped with trace, then the created object must be "traceable".
    """
    return getattr(cls_or_func, '_traced', False) and (
        not hasattr(cls_or_func, '__dict__') or  # in case it's a function
        '_traced' in cls_or_func.__dict__        # must be in this class; a traced super-class doesn't count
    )
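The `'_traced' in cls_or_func.__dict__` check above matters because `getattr` also sees attributes inherited from a traced superclass. A standalone sketch of the same own-attribute test (names are illustrative):

```python
def has_own_flag(cls, flag='_traced'):
    """Check that a flag is set on the class itself, not inherited.

    getattr() alone would also find the attribute on a flagged
    superclass, so the class's own __dict__ is consulted too.
    """
    return bool(getattr(cls, flag, False)) and flag in cls.__dict__

class Base:
    _traced = True  # stands in for a class wrapped with @nni.trace

class Child(Base):
    pass

assert has_own_flag(Base)
assert not has_own_flag(Child)  # inherited flag does not count
```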
class SerializableObject(Traceable):
@@ -120,6 +142,13 @@ class SerializableObject(Traceable):
            {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
        )

    def get(self) -> T:
        if not self._get_nni_attr('call_super'):
            # Reinitialize
            return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
        return self

    @property
    def trace_symbol(self) -> Any:
        return self._get_nni_attr('symbol')
@@ -160,6 +189,15 @@ class SerializableObject(Traceable):

def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
    # If an object is already created, this can be a fix so that the necessary info is re-injected into the object.
    # Make obj comply with the interface of traceable, though we cannot change its base class.
    obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
    return obj
def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
    # Make an already existing class traceable, without creating a new class.
    # Should be used together with `inject_trace_info`.
    def getter_factory(x):
        return lambda self: self.__dict__['_nni_' + x]
@@ -177,27 +215,29 @@ def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, An
            {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
        )

    def get(self):
        return self

    attributes = {
        'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
        'trace_args': property(getter_factory('args'), setter_factory('args')),
        'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
        'trace_copy': trace_copy,
        'get': get,
    }

    if not create_wrapper:
        for name, method in attributes.items():
            setattr(cls, name, method)
        return cls
    else:
        # Sometimes create_wrapper is mandatory, e.g., for built-in types like list/int,
        # but we don't check that here because the check is unreliable.
        wrapper = type('wrapper', (Traceable, cls), attributes)
        return wrapper
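The in-place vs. wrapper split in `_make_class_traceable` can be seen with a toy property: ordinary classes accept `setattr`, while built-in types reject class-level mutation and need a dynamically created subclass. A hedged sketch (`make_tagged` is illustrative, not NNI API):

```python
def make_tagged(cls, create_wrapper=False):
    """Add a `tag` property reading from the instance __dict__,
    either in place or via a dynamically created wrapper subclass."""
    prop = property(lambda self: self.__dict__['_tag'])
    if not create_wrapper:
        setattr(cls, 'tag', prop)  # mutate the class itself
        return cls
    # Built-in types like int reject setattr, so subclass instead.
    return type('wrapper', (cls,), {'tag': prop})

class Point:
    pass

p = make_tagged(Point)()
p.__dict__['_tag'] = 'traced'
assert p.tag == 'traced'

# int can't be mutated at the class level; take the wrapper path.
TaggedInt = make_tagged(int, create_wrapper=True)
n = TaggedInt(5)
n.__dict__['_tag'] = 'traced'
assert n == 5 and n.tag == 'traced'
```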
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]:
    """
    Annotate a function or a class if you want to preserve where it comes from.
    This is usually used in the following scenarios:
@@ -221,6 +261,9 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
    list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
    Therefore, in some cases, some positional arguments will still be kept.

    If ``inheritable`` is true, the trace information from the superclass will also be available in the subclass.
    This, however, will make the subclass un-trace-able. Note that this argument has no effect when tracing functions.

    .. warning::

        Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
@@ -235,12 +278,19 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
        pass
    """
    # This is an internal flag to control the behavior of trace.
    # Useful in doc builds and tests.
    # Might be changed in the future.
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls_or_func
    def wrap(cls_or_func):
        # already annotated, do nothing
        if is_wrapped_with_trace(cls_or_func):
            return cls_or_func
        if isinstance(cls_or_func, type):
            cls_or_func = _trace_cls(cls_or_func, kw_only, inheritable=inheritable)
        elif _is_function(cls_or_func):
            cls_or_func = _trace_func(cls_or_func, kw_only)
        else:
@@ -353,11 +403,60 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
    return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)
def _trace_cls(base, kw_only, call_super=True, inheritable=False):
    # The implementation to trace a class is to store a copy of the init arguments.
    # This won't support classes that define a customized __new__, but should work for most cases.

    if sys.platform != 'linux':
        if not call_super:
            raise ValueError("'call_super' must be set to true on non-linux platforms")
        try:
            # In non-linux envs, dynamically creating new classes doesn't work with pickle.
            # We have to replace the ``__init__`` with a new ``__init__``.
            # This, however, causes side-effects where the replacement is not intended.
            # This also doesn't work with built-in types (e.g., OrderedDict), and the replacement
            # won't be effective any more if ``nni.trace`` is called in-place (e.g., ``nni.trace(nn.Conv2d)(...)``).
            original_init = base.__init__

            # Make the new init have the exact same signature as the old one,
            # so as to make pytorch-lightning happy.
            # https://github.com/PyTorchLightning/pytorch-lightning/blob/4cc05b2cf98e49168a5f5dc265647d75d1d3aae9/pytorch_lightning/utilities/parsing.py#L143
            @functools.wraps(original_init)
            def new_init(self, *args, **kwargs):
                args, kwargs = _formulate_arguments(original_init, args, kwargs, kw_only, is_class_init=True)
                original_init(
                    self,
                    *[_argument_processor(arg) for arg in args],
                    **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
                )
                inject_trace_info(self, base, args, kwargs)

            base.__init__ = new_init

            base = _make_class_traceable(base)
            return base
        except TypeError:
            warnings.warn("In-place __init__ replacement failed in `@nni.trace`, probably because the type is a built-in/extension type, "
                          "and its __init__ can't be replaced. `@nni.trace` is now falling back to the 'inheritance' approach. "
                          "However, this could cause issues when using pickle. See https://github.com/microsoft/nni/issues/4434",
                          RuntimeWarning)

    # This is trying to solve the case where the superclass and subclass are both decorated with @nni.trace.
    # We use a metaclass to "unwrap" the superclass.
    # However, this doesn't work if:
    # 1. The base class already has a customized metaclass. We raise an error in that case.
    # 2. SerializableObject is an ancestor (instead of a parent). This case is rare and isn't handled yet. FIXME
    if type(base) is type and not inheritable:
        metaclass = _unwrap_metaclass
    else:
        metaclass = type
        if SerializableObject in inspect.getmro(base):
            raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
                            "Please either use the default metaclass, or remove trace from the super-class.")
    class wrapper(SerializableObject, base, metaclass=metaclass):
        def __init__(self, *args, **kwargs):
            # store a copy of initial parameters
            args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
@@ -365,6 +464,32 @@ def _trace_cls(base, kw_only, call_super=True):
            # calling serializable object init to initialize the full object
            super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)

        def __reduce__(self):
            # The issue that decorators and picklers don't play well together is well known.
            # The workaround is to use a stand-in class (_pickling_object) which pretends to be the pickled object.
            # We then put the original type, as well as args and kwargs, in its `__new__` arguments.
            # There could still be problems when things get complex,
            # e.g., when the wrapped class has a custom pickling (`__reduce__`) or `__new__`.
            # But it can't be worse, because the previous pickle didn't work at all.
            #
            # Linked issue: https://github.com/microsoft/nni/issues/4434
            # SO: https://stackoverflow.com/questions/52185507/pickle-and-decorated-classes-picklingerror-not-the-same-object

            # Store the inner class. The wrapped class couldn't be properly pickled.
            type_ = cloudpickle.dumps(type(self).__wrapped__)

            # In case they have a customized ``__getstate__``.
            if hasattr(self, '__getstate__'):
                obj_ = self.__getstate__()
            else:
                obj_ = self.__dict__

            # Pickle can't handle type objects.
            if '_nni_symbol' in obj_:
                obj_['_nni_symbol'] = cloudpickle.dumps(obj_['_nni_symbol'])

            return _pickling_object, (type_, kw_only, obj_)

    _copy_class_wrapper_attributes(base, wrapper)
    return wrapper
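The `__reduce__` workaround above boils down to handing pickle a module-level reconstruction callable instead of the decorated class itself. A minimal sketch of the pattern (class names are illustrative; NNI's real recipe additionally routes the type and symbol through `cloudpickle`):

```python
import pickle

class Reconstruct:
    """Stand-in callable used as the pickle entry point."""
    def __new__(cls, kls, state):
        # Rebuild the target object directly, bypassing __init__.
        obj = object.__new__(kls)
        obj.__dict__.update(state)
        return obj

class Wrapped:
    def __init__(self, x):
        self.x = x

    def __reduce__(self):
        # Hand pickle a plain reconstruction recipe instead of this class,
        # the pattern used to dodge "not the same object" errors
        # when a class has been replaced by a decorator wrapper.
        return Reconstruct, (Wrapped, self.__dict__)

w = pickle.loads(pickle.dumps(Wrapped(3)))
assert isinstance(w, Wrapped) and w.x == 3
```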
...@@ -391,6 +516,8 @@ def _trace_func(func, kw_only): ...@@ -391,6 +516,8 @@ def _trace_func(func, kw_only):
elif hasattr(res, '__class__') and hasattr(res, '__dict__'): elif hasattr(res, '__class__') and hasattr(res, '__dict__'):
# is a class, inject interface directly # is a class, inject interface directly
# need to be done before primitive types because there could be inheritance here. # need to be done before primitive types because there could be inheritance here.
if not getattr(type(res), '_traced', False):
_make_class_traceable(type(res), False) # in-place
res = inject_trace_info(res, func, args, kwargs)
elif isinstance(res, (collections.abc.Callable, types.ModuleType, IOBase)):
raise TypeError(f'Try to add trace info to {res}, but functions and modules are not supported.')
...@@ -400,6 +527,8 @@ def _trace_func(func, kw_only):
# will be directly captured by python json encoder
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type = _make_class_traceable(type(res), True)
res = new_type(res) # re-creating the object
res = inject_trace_info(res, func, args, kwargs)
else:
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
...@@ -425,6 +554,48 @@ def _copy_class_wrapper_attributes(base, wrapper):
wrapper.__wrapped__ = base
class _unwrap_metaclass(type):
# When a subclass is created, it detects whether the super-class is already annotated with @nni.trace.
# If yes, it gets the ``__wrapped__`` inner class, so that it doesn't inherit SerializableObject twice.
# Note that this doesn't work when metaclass is already defined (such as ABCMeta). We give up in that case.
def __new__(cls, name, bases, dct):
bases = tuple([getattr(base, '__wrapped__', base) for base in bases])
return super().__new__(cls, name, bases, dct)
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which concerns the wrapped class only.
def __subclasscheck__(cls, subclass):
inner_cls = getattr(cls, '__wrapped__', cls)
return inner_cls in inspect.getmro(subclass)
def __instancecheck__(cls, instance):
inner_cls = getattr(cls, '__wrapped__', cls)
return inner_cls in inspect.getmro(type(instance))
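A minimal, self-contained sketch of the unwrap behavior described above; ``_UnwrapMeta``, ``Inner``, ``Wrapper``, and ``Child`` are hypothetical names, with ``Wrapper`` standing in for a class decorated with ``@nni.trace``:

```python
import inspect

class _UnwrapMeta(type):
    # Replace wrapper bases with their inner class when a subclass is
    # created, and make isinstance/issubclass consult the inner class.
    def __new__(cls, name, bases, dct):
        bases = tuple(getattr(b, '__wrapped__', b) for b in bases)
        return super().__new__(cls, name, bases, dct)
    def __subclasscheck__(cls, subclass):
        inner = getattr(cls, '__wrapped__', cls)
        return inner in inspect.getmro(subclass)
    def __instancecheck__(cls, instance):
        inner = getattr(cls, '__wrapped__', cls)
        return inner in inspect.getmro(type(instance))

class Inner:
    pass

class Wrapper(Inner, metaclass=_UnwrapMeta):  # stands in for the traced wrapper
    __wrapped__ = Inner

class Child(Wrapper):  # metaclass swaps the base back to Inner
    pass

print(Child.__bases__)               # base is Inner, not Wrapper
print(isinstance(Inner(), Wrapper))  # True: the check falls back to Inner
print(issubclass(Child, Wrapper))    # True
```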
class _pickling_object:
# Need ``cloudpickle.loads`` on the callable because the callable is pickled with cloudpickle.
# Used in `_trace_cls`.
def __new__(cls, type_, kw_only, data):
type_ = _wrapped_cloudpickle_loads(type_)
# Restore the trace type
type_ = _trace_cls(type_, kw_only)
# restore type
if '_nni_symbol' in data:
data['_nni_symbol'] = _wrapped_cloudpickle_loads(data['_nni_symbol'])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj = type_.__new__(type_)
if hasattr(obj, '__setstate__'):
obj.__setstate__(data)
else:
obj.__dict__.update(data)
return obj
def _argument_processor(arg):
# 1) translate
# handle cases like ValueChoice
...@@ -520,7 +691,7 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
def _is_function(obj: Any) -> bool:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType,
types.BuiltinMethodType)) and obj is not None
def _import_cls_or_func_from_name(target: str) -> Any:
...@@ -533,7 +704,9 @@ def _import_cls_or_func_from_name(target: str) -> Any:
def _strip_trace_type(traceable: Any) -> Any:
if getattr(traceable, '_traced', False):
# sometimes, ``__wrapped__`` could be unavailable (e.g., with `inject_trace_info`)
# need to have a default value
return getattr(traceable, '__wrapped__', traceable)
return traceable
...@@ -571,7 +744,7 @@ def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096)
def import_cls_or_func_from_hybrid_name(s: str) -> Any:
if s.startswith('bytes:'):
b = base64.b64decode(s.split(':', 1)[-1])
return _wrapped_cloudpickle_loads(b)
if s.startswith('path:'):
s = s.split(':', 1)[-1]
return _import_cls_or_func_from_name(s)
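A self-contained sketch of the two-prefix scheme handled above, with plain ``pickle`` standing in for cloudpickle and ``import_from_hybrid_name`` as a hypothetical stand-in for the real functions:

```python
import base64
import importlib
import pickle
from collections import OrderedDict

def import_from_hybrid_name(s):
    # 'bytes:' carries a base64-encoded pickle payload;
    # 'path:' carries an importable dotted name.
    if s.startswith('bytes:'):
        return pickle.loads(base64.b64decode(s.split(':', 1)[-1]))
    if s.startswith('path:'):
        module_name, attr = s.split(':', 1)[-1].rsplit('.', 1)
        return getattr(importlib.import_module(module_name), attr)
    raise ValueError(f'Unrecognized hybrid name: {s}')

print(import_from_hybrid_name('path:collections.OrderedDict') is OrderedDict)  # True
payload = 'bytes:' + base64.b64encode(pickle.dumps(OrderedDict)).decode()
print(import_from_hybrid_name(payload) is OrderedDict)  # True
```

The ``path:`` form is preferred when the object is importable; ``bytes:`` is the fallback for objects (e.g., lambdas or dynamically created classes) that have no stable import path.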
...@@ -598,7 +771,7 @@ def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False,
# Encodes a serializable object instance to json.
# do nothing to instance that is not a serializable object and do not use trace
if not (use_trace and hasattr(obj, '__class__') and is_traceable(type(obj))):
return obj
if isinstance(obj.trace_symbol, property):
...@@ -644,5 +817,14 @@ def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__']
b = base64.b64decode(obj)
return _wrapped_cloudpickle_loads(b)
return obj
def _wrapped_cloudpickle_loads(b: bytes) -> Any:
try:
return cloudpickle.loads(b)
except TypeError:
warnings.warn('TypeError encountered when deserializing object. This could be caused by '
'inconsistency between the Python versions where dump and load happen.')
raise
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import sys
import warnings
import cloudpickle
import json_tricks
import numpy
import yaml
import nni
def _minor_version_tuple(version_str: str) -> tuple[int, int]:
# If not a number, returns -1 (e.g., 999.dev0 -> (999, -1))
return tuple(int(x) if x.isdigit() else -1 for x in version_str.split(".")[:2])
PYTHON_VERSION = sys.version_info[:2]
NUMPY_VERSION = _minor_version_tuple(numpy.__version__)
try:
import torch
TORCH_VERSION = _minor_version_tuple(torch.__version__)
except ImportError:
logging.info("PyTorch is not installed.")
TORCH_VERSION = None
try:
import pytorch_lightning
PYTORCH_LIGHTNING_VERSION = _minor_version_tuple(pytorch_lightning.__version__)
except ImportError:
logging.info("PyTorch Lightning is not installed.")
PYTORCH_LIGHTNING_VERSION = None
try:
import tensorflow
TENSORFLOW_VERSION = _minor_version_tuple(tensorflow.__version__)
except ImportError:
logging.info("Tensorflow is not installed.")
TENSORFLOW_VERSION = None
# Serialization version checks are needed because these libraries are prone to inconsistencies between versions
CLOUDPICKLE_VERSION = _minor_version_tuple(cloudpickle.__version__)
JSON_TRICKS_VERSION = _minor_version_tuple(json_tricks.__version__)
PYYAML_VERSION = _minor_version_tuple(yaml.__version__)
NNI_VERSION = _minor_version_tuple(nni.__version__)
def version_dump() -> dict[str, tuple[int, int] | None]:
return {
'python': PYTHON_VERSION,
'numpy': NUMPY_VERSION,
'torch': TORCH_VERSION,
'pytorch_lightning': PYTORCH_LIGHTNING_VERSION,
'tensorflow': TENSORFLOW_VERSION,
'cloudpickle': CLOUDPICKLE_VERSION,
'json_tricks': JSON_TRICKS_VERSION,
'pyyaml': PYYAML_VERSION,
'nni': NNI_VERSION
}
def version_check(expect: dict, raise_error: bool = False) -> None:
current_ver = version_dump()
for package in expect:
# version could be a list due to serialization
exp_version: tuple | None = tuple(expect[package]) if expect[package] else None
if exp_version is None:
continue
err_message: str | None = None
if package not in current_ver:
err_message = f'{package} is missing in current environment'
elif current_ver[package] != exp_version:
err_message = f'Expect {package} to have version {exp_version}, but {current_ver[package]} found'
if err_message:
if raise_error:
raise RuntimeError('Version check failed: ' + err_message)
else:
warnings.warn('Version check with warning: ' + err_message)
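A self-contained sketch of the comparison logic above; unlike the real function, the current environment is passed in as a dict rather than taken from ``version_dump()``:

```python
import warnings

def version_check(expect, current, raise_error=False):
    # Compare expected (major, minor) tuples against the current environment.
    for package, ver in expect.items():
        exp = tuple(ver) if ver else None  # versions may arrive as lists from JSON
        if exp is None:
            continue
        if package not in current:
            message = f'{package} is missing in current environment'
        elif current[package] != exp:
            message = f'Expect {package} to have version {exp}, but {current[package]} found'
        else:
            continue
        if raise_error:
            raise RuntimeError('Version check failed: ' + message)
        warnings.warn('Version check with warning: ' + message)

current = {'python': (3, 8), 'torch': (1, 10)}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # python matches, torch mismatches, tensorflow (None) is skipped
    version_check({'python': [3, 8], 'torch': [1, 11], 'tensorflow': None}, current)
print(len(caught))  # 1: only the torch mismatch warns
```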
...@@ -3,4 +3,4 @@
from .speedup import ModelSpeedup
from .compressor import Compressor, Pruner, Quantizer
from .utils.apply_compression import apply_compression_results
...@@ -631,6 +631,7 @@ class Quantizer(Compressor):
"""
quantize should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.

Parameters
----------
wrapper : QuantizerModuleWrapper
...@@ -642,6 +643,7 @@ class Quantizer(Compressor):
""" """
quantize should overload this method to quantize output. quantize should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model. This method is effectively hooked to :meth:`forward` of the model.
Parameters Parameters
---------- ----------
output : Tensor output : Tensor
...@@ -655,6 +657,7 @@ class Quantizer(Compressor): ...@@ -655,6 +657,7 @@ class Quantizer(Compressor):
""" """
quantize should overload this method to quantize input. quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model. This method is effectively hooked to :meth:`forward` of the model.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
...@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function): ...@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function):
def _quantize(cls, x, scale, zero_point): def _quantize(cls, x, scale, zero_point):
""" """
Reference function for quantizing x -- non-clamped. Reference function for quantizing x -- non-clamped.
Parameters Parameters
---------- ----------
x : Tensor x : Tensor
...@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function): ...@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function):
scale for quantizing x scale for quantizing x
zero_point : Tensor zero_point : Tensor
zero_point for quantizing x zero_point for quantizing x
Returns Returns
------- -------
tensor tensor
...@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function): ...@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function):
def get_bits_length(cls, config, quant_type):
"""
Get bits for quantize config

Parameters
----------
config : Dict
the configuration for quantization
quant_type : str
quant type

Returns
-------
int
...@@ -948,6 +955,7 @@ class QuantGrad(torch.autograd.Function):
""" """
This method should be overrided by subclass to provide customized backward function, This method should be overrided by subclass to provide customized backward function,
default implementation is Straight-Through Estimator default implementation is Straight-Through Estimator
Parameters Parameters
---------- ----------
tensor : Tensor tensor : Tensor
...@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function): ...@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function):
quant_min for quantizing tensor quant_min for quantizing tensor
qmax : Tensor qmax : Tensor
quant_max for quantizng tensor quant_max for quantizng tensor
Returns Returns
------- -------
tensor tensor
......