"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "95ec8e51e40ce21168cb79e08540491049219405"
Unverified Commit d2c610a1 authored by Yuge Zhang, committed by GitHub

Update guide and reference of NAS (#1972)

parent d8388957
@@ -19,7 +19,19 @@ class Mutable(nn.Module):
    decisions among different mutables. In a mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global. By default, keys use a global counter starting from 1
    to produce unique ids.

    Parameters
    ----------
    key : str
        The key of mutable.

    Notes
    -----
    The counter is program-level, but mutables are model-level. In case multiple models are defined and
    you want the counter to start from 1 in the second model, it is recommended to assign keys manually
    instead of relying on automatic keys.
    """

    def __init__(self, key=None):
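To make the note above concrete, here is a minimal sketch of assigning keys manually so that two models share stable keys regardless of the program-level counter. The import path follows the `from nni.nas.pytorch.mutables import ...` line that appears later in this diff; the two-model setup itself is hypothetical.

```python
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice  # path taken from the import shown later in this diff

class NetV1(nn.Module):
    def __init__(self):
        super().__init__()
        # Explicit key: stable across models, independent of the global counter.
        self.op = LayerChoice([nn.Conv2d(3, 16, 3, padding=1),
                               nn.Conv2d(3, 16, 5, padding=2)], key="first_conv")

class NetV2(nn.Module):
    def __init__(self):
        super().__init__()
        # With automatic keys, this mutable would continue the program-level counter
        # (e.g., start at 2 if NetV1 was built first); an explicit key avoids that.
        self.op = LayerChoice([nn.Conv2d(3, 16, 3, padding=1),
                               nn.Conv2d(3, 16, 5, padding=2)], key="first_conv")
```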
@@ -51,10 +63,16 @@ class Mutable(nn.Module):

    @property
    def key(self):
        """
        Read-only property of key.
        """
        return self._key

    @property
    def name(self):
        """
        After the search space is parsed, it will be the module name of the mutable.
        """
        return self._name if hasattr(self, "_name") else "_key"

    @name.setter
@@ -75,11 +93,23 @@ class Mutable(nn.Module):

class MutableScope(Mutable):
    """
    Mutable scope marks a subgraph/submodule to help mutators make better decisions.

    If not annotated with a mutable scope, the search space will be flattened into a list. However, some
    mutators might need to leverage the concept of a "cell". If a module is defined as a mutable scope,
    everything inside it will look like a "sub-search-space" within the scope. Scopes can be nested.

    There are two ways mutators can use mutable scopes. One is to traverse the search space as a tree during
    initialization and reset. The other is to implement ``enter_mutable_scope`` and ``exit_mutable_scope``,
    which are called before and after the forward method of the class inheriting mutable scope.

    Mutable scopes are also mutables that are listed in ``mutator.mutables`` (search space), but they are not
    supposed to appear in the dict of choices.

    Parameters
    ----------
    key : str
        Key of mutable scope.
    """

    def __init__(self, key):
        super().__init__(key=key)
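A minimal sketch of a cell annotated as a mutable scope, assuming the classes are importable from `nni.nas.pytorch.mutables` as shown elsewhere in this diff; the layer candidates are placeholders.

```python
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice, MutableScope

class Cell(MutableScope):
    # Everything defined inside this scope appears as a sub-search-space
    # instead of being flattened into the global list of mutables.
    def __init__(self, key):
        super().__init__(key)
        self.op = LayerChoice([nn.ReLU(), nn.Tanh()], key=key + "_op")

    def forward(self, x):
        # Mutators that override enter_mutable_scope / exit_mutable_scope
        # are notified around this call.
        return self.op(x)
```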
@@ -93,6 +123,31 @@ class MutableScope(Mutable):

class LayerChoice(Mutable):
    """
    Layer choice selects one of the ``op_candidates``, applies it to the inputs, and returns the result.
    In rare cases, it can also select zero or many.

    Layer choice does not allow itself to be nested.

    Parameters
    ----------
    op_candidates : list of nn.Module
        A module list to be selected from.
    reduction : str
        ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
        If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum.
        ``concat`` concatenates the list at dimension 1.
    return_mask : bool
        If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
    key : str
        Key of the layer choice.

    Attributes
    ----------
    length : int
        Number of ops to choose from.
    """

    def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
        super().__init__(key=key)
        self.length = len(op_candidates)

@@ -101,6 +156,12 @@ class LayerChoice(Mutable):
        self.return_mask = return_mask

    def forward(self, *inputs):
        """
        Returns
        -------
        tuple of tensors
            Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
        """
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
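To make the reduction policies concrete, here is a plain-PyTorch sketch of the semantics documented above; this is an illustration of the docstring, not NNI's actual implementation.

```python
import torch

def reduce_outputs(tensors, reduction):
    # Mirrors the documented policies for LayerChoice/InputChoice.
    if reduction == "none":
        return tensors                        # return the list as-is
    if reduction == "sum":
        return sum(tensors)                   # element-wise sum
    if reduction == "mean":
        return sum(tensors) / len(tensors)    # element-wise average
    if reduction == "concat":
        return torch.cat(tensors, dim=1)      # concatenate at dimension 1
    raise ValueError("unsupported reduction: " + reduction)

outs = [torch.ones(2, 3), torch.zeros(2, 3)]
print(reduce_outputs(outs, "concat").shape)   # torch.Size([2, 6])
```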
@@ -109,42 +170,62 @@ class LayerChoice(Mutable):

class InputChoice(Mutable):
    """
    Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys). For beginners,
    using ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might
    want to know about ``choose_from``.

    The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
    The keys are designed to be the keys of the sources. To help mutators make better decisions,
    mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
    output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
    ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
    module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed.

    In the example below, ``input_choice`` is a 4-choose-any. The first three are semantically the outputs of
    cell1, cell2, and op, respectively. Notice that an extra max pooling follows cell1, indicating that x1 is not
    "actually" the direct output of cell1.

    .. code-block:: python

        class Cell(MutableScope):
            pass

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.cell1 = Cell("cell1")
                self.cell2 = Cell("cell2")
                self.op = LayerChoice([conv3x3(), conv5x5()], key="op")
                self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY])

            def forward(self, x):
                x1 = max_pooling(self.cell1(x))
                x2 = self.cell2(x)
                x3 = self.op(x)
                x4 = torch.zeros_like(x)
                return self.input_choice([x1, x2, x3, x4])

    Parameters
    ----------
    n_candidates : int
        Number of inputs to choose from.
    choose_from : list of str
        List of source keys to choose from. At least one of ``choose_from`` and ``n_candidates`` must be fulfilled.
        If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as
        ``n_candidates`` number of empty strings.
    n_chosen : int
        Recommended inputs to choose. If None, mutator is instructed to select any.
    reduction : str
        ``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`.
    return_mask : bool
        If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
    key : str
        Key of the input choice.
    """

    NO_KEY = ""

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
                 reduction="sum", return_mask=False, key=None):
        super().__init__(key=key)
        # precondition check
        assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
@@ -172,12 +253,13 @@ class InputChoice(Mutable):
        ----------
        optional_inputs : list or dict
            Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
            ``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as
            ``choose_from``.

        Returns
        -------
        tuple of tensors
            Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
        """
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
...
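A small sketch of the two call forms described in the forward docstring above, assuming `InputChoice` is importable from `nni.nas.pytorch.mutables`; the surrounding module is hypothetical.

```python
import torch.nn as nn
from nni.nas.pytorch.mutables import InputChoice

class Merge(nn.Module):
    def __init__(self):
        super().__init__()
        self.skip = InputChoice(choose_from=["cell1", "cell2"], n_chosen=1)

    def forward(self, x1, x2):
        # Dict form (recommended): reordered to follow choose_from before selection.
        out = self.skip({"cell2": x2, "cell1": x1})
        # List form: the caller must match choose_from's order.
        # out = self.skip([x1, x2])
        return out
```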
@@ -43,10 +43,6 @@ class Mutator(BaseMutator):
        """
        Reset the mutator by calling `sample_search` to resample (for search). Stores the result in a local
        variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
        """
        self._cache = self.sample_search()

@@ -57,25 +53,28 @@ class Mutator(BaseMutator):
        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        return self.sample_final()

    def on_forward_layer_choice(self, mutable, *inputs):
        """
        By default, this method retrieves the decision obtained previously, and selects certain operations.
        Only operations with non-zero weight will be executed. The results will be added to a list.
        Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.

        Parameters
        ----------
        mutable : LayerChoice
            Layer choice module.
        inputs : list of torch.Tensor
            Inputs to the layer choice.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output and mask.
        """
        def _map_fn(op, *inputs):
            return op(*inputs)
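The default behavior described above can be pictured with a short sketch; this is a guess at the semantics for illustration, not the real method (which, per the docstring, also weights outputs when the mask is not boolean).

```python
import torch
import torch.nn as nn

def apply_layer_choice(ops, mask, *inputs):
    # Run only candidates with non-zero weight, scale their outputs by the
    # weight, collect into a list, then sum-reduce (the default policy).
    outputs = [weight * op(*inputs) for op, weight in zip(ops, mask) if weight]
    return sum(outputs), mask

ops = nn.ModuleList([nn.ReLU(), nn.Tanh()])
out, mask = apply_layer_choice(ops, torch.tensor([1.0, 0.0]), torch.randn(2, 4))
```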
@@ -87,20 +86,20 @@ class Mutator(BaseMutator):

    def on_forward_input_choice(self, mutable, tensor_list):
        """
        By default, this method retrieves the decision obtained previously, and selects certain tensors.
        Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.

        Parameters
        ----------
        mutable : InputChoice
            Input choice module.
        tensor_list : list of torch.Tensor
            Tensor list to apply the decision on.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output and mask.
        """
        mask = self._get_decision(mutable)
        assert len(mask) == mutable.n_candidates, \
...
@@ -6,7 +6,15 @@ from nni.nas.pytorch.mutables import LayerChoice, InputChoice

class RandomMutator(Mutator):
    """
    Random mutator that samples a random candidate in the search space each time ``reset()`` is called.
    It uses the random functions in PyTorch, so users can set a seed in PyTorch to ensure deterministic behavior.
    """

    def sample_search(self):
        """
        Sample a random candidate.
        """
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):

@@ -22,4 +30,7 @@ class RandomMutator(Mutator):
        return result

    def sample_final(self):
        """
        Same as :meth:`sample_search`.
        """
        return self.sample_search()
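An end-to-end sketch of sampling with the random mutator. The `RandomMutator` import path is an assumption (the class is defined in the file shown in this hunk); the rest follows the documented API, and `export()` is grounded by the `Trainer.export` hunk further down.

```python
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.random import RandomMutator  # module path is an assumption

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.ReLU(), nn.Tanh()], key="act")

    def forward(self, x):
        return self.op(x)

torch.manual_seed(0)           # PyTorch seed => deterministic sampling
model = Net()
mutator = RandomMutator(model)
mutator.reset()                # sample a random candidate for this forward pass
y = model(torch.randn(2, 8))
print(mutator.export())        # final decision; same as sample_search here
```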
@@ -16,11 +16,8 @@ _logger = logging.getLogger(__name__)

class SPOSEvolution(Tuner):
    """
    SPOS evolution tuner.

    Parameters
    ----------

@@ -39,6 +36,9 @@ class SPOSEvolution(Tuner):
    num_mutation : int
        Number of candidates generated by mutation in each epoch.
    """

    def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
                 num_crossover=25, num_mutation=25):
        assert num_population >= num_select
        self.max_epochs = max_epochs
        self.num_select = num_select
...
@@ -10,9 +10,8 @@ _logger = logging.getLogger(__name__)

class SPOSSupernetTrainingMutator(RandomMutator):
    """
    A random mutator with flops limit.

    Parameters
    ----------

@@ -31,6 +30,9 @@ class SPOSSupernetTrainingMutator(RandomMutator):
    flops_sample_timeout : int
        Maximum number of attempts to sample before giving up and using a random candidate.
    """

    def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
                 flops_bin_num=7, flops_sample_timeout=500):
        super().__init__(model)
        self._flops_func = flops_func
        if self._flops_func is not None:
...
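A hypothetical `flops_func` showing how the FLOPs limit could be wired up. The expected signature (taking the sampled candidate dict) and the lookup table are assumptions for illustration, as is the `nni.nas.pytorch.spos` import path.

```python
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.spos import SPOSSupernetTrainingMutator  # path is an assumption

class Supernet(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.Conv2d(3, 16, 3, padding=1),
                               nn.Conv2d(3, 16, 5, padding=2)], key="op")

    def forward(self, x):
        return self.op(x)

FLOPS_TABLE = {"op": [9.0e6, 25.0e6]}  # made-up per-candidate costs

def flops_func(candidate):
    # Assumed signature: candidate maps mutable keys to boolean selection masks.
    total = 0.0
    for key, mask in candidate.items():
        total += sum(cost for cost, chosen in zip(FLOPS_TABLE[key], mask) if chosen)
    return total

mutator = SPOSSupernetTrainingMutator(Supernet(), flops_func=flops_func,
                                      flops_lb=5e6, flops_ub=20e6)
```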
@@ -15,13 +15,7 @@ logger = logging.getLogger(__name__)

class SPOSSupernetTrainer(Trainer):
    """
    This trainer trains a supernet that can be used for evolution search.

    Parameters
    ----------
    model : nn.Module

@@ -52,6 +46,11 @@ class SPOSSupernetTrainer(Trainer):
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None):
        assert torch.cuda.is_available()
        super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
                         loss, metrics, optimizer, num_epochs, None, None,
...
@@ -24,10 +24,9 @@ class TorchTensorEncoder(json.JSONEncoder):

class Trainer(BaseTrainer):
    """
    A trainer with some helper functions implemented. To implement a new trainer,
    users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

    Parameters
    ----------

@@ -37,14 +36,22 @@ class Trainer(BaseTrainer):
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
        See `PyTorch loss functions`_ for examples.
    metrics : callable
        Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,

        .. code-block:: python

            def metrics_fn(output, target):
                return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    dataset_train : torch.utils.data.Dataset
        Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
        PyTorch Dataset. See `torch.utils.data`_ for examples.
    dataset_valid : torch.utils.data.Dataset
        Dataset of validation/testing.
    batch_size : int

@@ -58,8 +65,13 @@ class Trainer(BaseTrainer):
        Number of mini-batches to log metrics.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.

    .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
    .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = model
        self.mutator = mutator
@@ -84,13 +96,38 @@ class Trainer(BaseTrainer):

    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        pass

    def train(self, validate=True):
        """
        Train for ``num_epochs`` epochs.
        Trigger callbacks at the start and the end of each epoch.

        Parameters
        ----------
        validate : bool
            If ``True``, do validation every epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

@@ -108,12 +145,26 @@ class Trainer(BaseTrainer):
                callback.on_epoch_end(epoch)

    def validate(self):
        """
        Do one validation.
        """
        self.validate_one_epoch(-1)

    def export(self, file):
        """
        Call ``mutator.export()`` and dump the architecture to ``file``.

        Parameters
        ----------
        file : str
            A file path. Expected to be a JSON file.
        """
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """
        Return trainer checkpoint.
        """
        raise NotImplementedError("Not implemented yet")
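The `train` loop above invokes `on_epoch_begin` and `on_epoch_end` on each callback. A minimal duck-typed sketch matching those two hooks; a real callback would subclass the `Callback` base class referenced in the docstrings, which is not shown in this diff.

```python
class PrintingCallback:
    # Matches the two hooks used by Trainer.train() above.
    def on_epoch_begin(self, epoch):
        print(f"epoch {epoch} begins")

    def on_epoch_end(self, epoch):
        print(f"epoch {epoch} ends")
```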
@@ -12,12 +12,18 @@ _logger = logging.getLogger(__name__)

def global_mutable_counting():
    """
    A program-level counter starting from 1.
    """
    global _counter
    _counter += 1
    return _counter

def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
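Example usage of `to_device` on a nested batch, assuming the helper is importable from `nni.nas.pytorch.utils` (the file shown here):

```python
import torch
from nni.nas.pytorch.utils import to_device  # path is an assumption

batch = {"x": torch.randn(4, 3),
         "labels": [torch.tensor(0), torch.tensor(1)]}
batch = to_device(batch, torch.device("cpu"))  # recurses into dict and list
```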
@@ -32,12 +38,18 @@ def to_device(obj, device):

class AverageMeterGroup:
    """
    Average meter group for multiple average meters.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Average meters that do not exist yet will be created automatically.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
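A short usage sketch of the meter group, assuming it is importable from `nni.nas.pytorch.utils`:

```python
from nni.nas.pytorch.utils import AverageMeterGroup  # path is an assumption

meters = AverageMeterGroup()
meters.update({"loss": 0.9, "acc1": 0.40})  # meters are created on first use
meters.update({"loss": 0.7, "acc1": 0.50})  # running averages are maintained
print(meters.summary())                      # one summary string for the group
```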
@@ -53,15 +65,15 @@ class AverageMeterGroup:
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------

@@ -70,17 +82,32 @@ class AverageMeter:
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted for.
        n : int
            The weight of the new value.
        """
        if not isinstance(val, float) and not isinstance(val, int):
            _logger.warning("Values passed to AverageMeter must be number, not %s.", type(val))
        self.val = val
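To illustrate the weight parameter `n`, here is a sketch of the usual running-average arithmetic, assuming the standard `sum += val * n; count += n; avg = sum / count` behavior implied by the fields reset above:

```python
from nni.nas.pytorch.utils import AverageMeter  # path is an assumption

meter = AverageMeter("acc1", ":.3f")
meter.update(0.5, n=10)   # e.g., accuracy 0.5 measured over 10 samples
meter.update(1.0, n=30)   # accuracy 1.0 over 30 samples
# Expected: (0.5 * 10 + 1.0 * 30) / (10 + 30) = 0.875
print(meter.avg)
```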
@@ -104,6 +131,11 @@ class StructuredMutableTreeNode:
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):

@@ -111,10 +143,16 @@ class StructuredMutableTreeNode:
        self.children = []

    def add_child(self, mutable):
        """
        Add a tree node to the children list of current node.
        """
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
...
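A small sketch of building the flattened tree by hand, assuming `StructuredMutableTreeNode` is importable from `nni.nas.pytorch.utils`; the placeholder strings stand in for real `Mutable` instances:

```python
from nni.nas.pytorch.utils import StructuredMutableTreeNode  # path is an assumption

root = StructuredMutableTreeNode(None)       # sentinel root with no mutable
scope = root.add_child("a-mutable-scope")    # placeholder; normally a MutableScope
scope.add_child("a-layer-choice")            # placeholder leaf; normally a Mutable
print(scope.type())                          # type of the stored content (str here)
```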