# Commit 05c7d6e9 ("Typehint for oneshot NAS (#4811)"), authored by Yuge Zhang; parent cbac2c5c.
# NOTE(review): this file is residue of a web-rendered diff — old and new variants of many
# definitions appear interleaved below and should be resolved to the new (committed) version.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import warnings
from itertools import chain
from typing import Dict, Callable, List, Union, Any, Tuple
from typing import Callable, Any, Dict, Union, Tuple, List, cast
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
......@@ -24,8 +27,8 @@ MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[
def traverse_and_mutate_submodules(
root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
) -> List[BaseSuperNetModule]:
root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
) -> list[BaseSuperNetModule]:
"""
Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.
......@@ -36,7 +39,7 @@ def traverse_and_mutate_submodules(
Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
it's usually a ``pytorch_lightning.LightningModule``.
The mutation will be in-place on ``root_module``.
hooks : List[MutationHook]
hooks : list[MutationHook]
List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
When a hook returns an module, the module will be replaced (mutated) to the new module.
mutate_kwargs : dict
......@@ -47,7 +50,7 @@ def traverse_and_mutate_submodules(
Returns
----------
modules : Dict[str, nn.Module]
modules : dict[str, nn.Module]
The replace result.
"""
memo = {}
......@@ -101,7 +104,7 @@ def traverse_and_mutate_submodules(
return module_list
def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
"""Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
# Forward IS NOT supernet
......@@ -125,7 +128,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
if is_traceable(module):
# check whether there is a value-choice in its arguments
has_valuechoice = False
for arg in chain(module.trace_args, module.trace_kwargs.values()):
for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
break
......@@ -139,7 +142,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
class BaseOneShotLightningModule(pl.LightningModule):
_mutation_hooks_note = """mutation_hooks : List[MutationHook]
_mutation_hooks_note = """mutation_hooks : list[MutationHook]
Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
For each submodule, the hook list are invoked subsequently,
......@@ -194,36 +197,40 @@ class BaseOneShotLightningModule(pl.LightningModule):
Attributes
----------
nas_modules : List[BaseSuperNetModule]
nas_modules : list[BaseSuperNetModule]
Modules that have been mutated, which the search algorithms should care about.
Parameters
----------
""" + _inner_module_note + _mutation_hooks_note
trainer: pl.Trainer

@property
def automatic_optimization(self) -> bool:
    """Disable lightning's automatic optimization.

    Gradients and optimizer steps are driven manually by the one-shot
    algorithms (see ``manual_backward`` / ``call_weight_optimizers`` usages
    in subclasses), so this always returns ``False``.
    """
    # Diff residue resolved: the old class attribute ``automatic_optimization = False``
    # was replaced by this read-only property in the new version.
    return False
def default_mutation_hooks(self) -> list[MutationHook]:
    """Override this to define class-default mutation hooks.

    Returns
    -------
    list[MutationHook]
        The base implementation returns only :func:`no_default_hook`, which
        raises an error for unsupported mutation primitives; subclasses
        prepend their algorithm-specific hooks before it.
    """
    return [no_default_hook]
def mutate_kwargs(self) -> dict[str, Any]:
    """Extra keyword arguments passed to mutation hooks. Usually algo-specific.

    The base implementation passes nothing; subclasses override this to
    forward algorithm-specific options to their hooks.
    """
    return {}
def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
    """Wrap ``model`` and mutate its supported submodules into supernet modules.

    Parameters
    ----------
    model
        The user's lightning module; it is mutated in-place.
    mutation_hooks
        Extra hooks, tried before the class defaults from
        :meth:`default_mutation_hooks`.
    """
    super().__init__()
    assert isinstance(model, pl.LightningModule)
    self.model = model

    # append the default hooks
    mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

    # traverse the model, calling hooks on every submodule
    self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
        self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)
def search_space_spec(self) -> Dict[str, ParameterSpec]:
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""Get the search space specification from ``nas_module``.
Returns
......@@ -236,7 +243,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
result.update(module.search_space_spec())
return result
def resample(self) -> Dict[str, Any]:
def resample(self) -> dict[str, Any]:
"""Trigger the resample for each ``nas_module``.
Sometimes (e.g., in differentiable cases), it does nothing.
......@@ -250,7 +257,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
result.update(module.resample(memo=result))
return result
def export(self) -> Dict[str, Any]:
def export(self) -> dict[str, Any]:
"""
Export the NAS result, ideally the best choice of each ``nas_module``.
You may implement an ``export`` method for your customized ``nas_module``.
......@@ -291,12 +298,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
arc_optimizers = [arc_optimizers]
self.arc_optim_count = len(arc_optimizers)
# FIXME: this part uses non-official lightning API.
# The return values ``frequency`` and ``monitor`` are ignored because lightning requires
# ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
# For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
w_optimizers, lr_schedulers, self.frequencies, monitor = \
self.trainer._configure_optimizers(self.model.configure_optimizers())
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
try:
# above v1.6
from pytorch_lightning.core.optimizer import ( # pylint: disable=import-error
_configure_optimizers, # type: ignore
_configure_schedulers_automatic_opt, # type: ignore
_configure_schedulers_manual_opt # type: ignore
)
w_optimizers, lr_schedulers, self.frequencies, monitor = \
_configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = (
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
if self.automatic_optimization
else _configure_schedulers_manual_opt(lr_schedulers)
)
except ImportError:
# under v1.5
w_optimizers, lr_schedulers, self.frequencies, monitor = \
self.trainer._configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
raise Exception(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
......@@ -312,7 +337,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_train_start()
......@@ -359,7 +384,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
Returns
----------
arc_optimizers : List[Optimizer], Optimizer
arc_optimizers : list[Optimizer], Optimizer
Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
"""
return None
......@@ -376,9 +401,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
"""
def apply(lr_scheduler):
# single scheduler is called every epoch
if isinstance(lr_scheduler, _LRScheduler) and \
self.trainer.is_last_batch:
lr_schedulers.step()
if isinstance(lr_scheduler, _LRScheduler):
if self.trainer.is_last_batch:
lr_scheduler.step()
# lr_scheduler_config is called as configured
elif isinstance(lr_scheduler, dict):
interval = lr_scheduler['interval']
......@@ -392,7 +417,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
self.trainer.is_last_batch and
(self.trainer.current_epoch + 1) % frequency == 0
):
lr_scheduler.step()
lr_scheduler['scheduler'].step()
lr_schedulers = self.lr_schedulers()
......@@ -402,7 +427,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
else:
apply(lr_schedulers)
def call_user_optimizers(self, method):
def call_weight_optimizers(self, method):
"""
Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
......@@ -418,10 +443,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
elif method == 'zero_grad':
optimizer.zero_grad()
optimizers = self.user_optimizers
optimizers = self.weight_optimizers()
if optimizers is None:
return
assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'
if len(self.frequencies) > 0:
self.cur_optimizer_step += 1
if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
......@@ -434,14 +461,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
for optimizer in optimizers:
apply_method(optimizer, method)
@property
def architecture_optimizers(self):
def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
"""
Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.
Returns
----------
opts : List[Optimizer], Optimizer, None
opts : list[Optimizer], Optimizer, None
Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there is no
architecture optimizers.
"""
......@@ -450,28 +476,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
# pylint: disable=unsubscriptable-object
arc_opts = opts[:self.arc_optim_count]
if len(arc_opts) == 1:
arc_opts = arc_opts[0]
return arc_opts
return cast(Optimizer, arc_opts[0])
return cast(List[Optimizer], arc_opts)
# If there is only 1 optimizer and it is the architecture optimizer
if self.arc_optim_count == 1:
return opts
return cast(Union[List[Optimizer], Optimizer], opts)
return None
def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
    """
    Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

    Note: this is a plain method in the new version (the old ``user_optimizers``
    property was renamed and de-propertied; call sites invoke it with ``()``).

    Returns
    ----------
    opts : list[Optimizer], Optimizer, None
        Optimizers defined by user's model. This will be None if there is no user optimizers.
    """
    # Since use_pl_optimizer is set true (by default) here,
    # opts always return a list
    opts = self.optimizers()
    if isinstance(opts, list):
        # pylint: disable=unsubscriptable-object
        # architecture optimizers occupy the first ``arc_optim_count`` slots
        return cast(List[Optimizer], opts[self.arc_optim_count:])
    # FIXME: this case is actually not correctly handled
    # If there is only 1 optimizer and no architecture optimizer
    if self.arc_optim_count == 0:
        return cast(Union[List[Optimizer], Optimizer], opts)
    return None
......@@ -5,6 +5,7 @@
import copy
import logging
import warnings
from collections import OrderedDict
import torch
......@@ -111,6 +112,8 @@ class DartsTrainer(BaseOneShotTrainer):
learning_rate=2.5E-3, batch_size=64, workers=4,
device=None, log_frequency=None,
arc_learning_rate=3.0E-4, unrolled=False):
warnings.warn('DartsTrainer is deprecated. Please use strategy.DARTS instead.', DeprecationWarning)
self.model = model
self.loss = loss
self.metrics = metrics
......
......@@ -3,9 +3,11 @@
"""Experimental version of differentiable one-shot implementation."""
from typing import List
from __future__ import annotations
import pytorch_lightning as pl
import torch
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
......@@ -45,7 +47,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
module_params=BaseOneShotLightningModule._inner_module_note,
)
def default_mutation_hooks(self) -> List[MutationHook]:
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
......@@ -62,14 +64,16 @@ class DartsLightningModule(BaseOneShotLightningModule):
}
def __init__(self, inner_module: pl.LightningModule,
             mutation_hooks: list[MutationHook] | None = None,
             arc_learning_rate: float = 3.0E-4):
    """DARTS one-shot module.

    Parameters
    ----------
    inner_module
        The user's lightning module to mutate.
    mutation_hooks
        Extra mutation hooks, forwarded to the base class.
    arc_learning_rate
        Learning rate for the architecture optimizer. Default: 3.0e-4.
    """
    # set before super().__init__ so it is available while the base class runs
    self.arc_learning_rate = arc_learning_rate
    super().__init__(inner_module, mutation_hooks=mutation_hooks)
def training_step(self, batch, batch_idx):
# grad manually
arc_optim = self.architecture_optimizers
arc_optim = self.architecture_optimizers()
if not isinstance(arc_optim, optim.Optimizer):
raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')
# The InterleavedTrainValDataLoader yields both train and val data in a batch
trn_batch, val_batch = batch
......@@ -88,12 +92,12 @@ class DartsLightningModule(BaseOneShotLightningModule):
# phase 2: model step
self.resample()
self.call_user_optimizers('zero_grad')
self.call_weight_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
w_step_loss = loss_and_metrics['loss'] \
if isinstance(loss_and_metrics, dict) else loss_and_metrics
self.manual_backward(w_step_loss)
self.call_user_optimizers('step')
self.call_weight_optimizers('step')
self.call_lr_schedulers(batch_idx)
......@@ -107,7 +111,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
# The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
ctrl_params = []
for m in self.nas_modules:
ctrl_params += list(m.parameters(arch=True))
ctrl_params += list(m.parameters(arch=True)) # type: ignore
ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
weight_decay=1.0E-3)
return ctrl_optim
......@@ -135,7 +139,7 @@ class ProxylessLightningModule(DartsLightningModule):
module_params=BaseOneShotLightningModule._inner_module_note,
)
def default_mutation_hooks(self) -> List[MutationHook]:
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with gumbel-differentiable versions"""
hooks = [
ProxylessMixedLayer.mutate,
......@@ -147,7 +151,7 @@ class ProxylessLightningModule(DartsLightningModule):
def finalize_grad(self):
    """Call ``finalize_grad`` on every mutated NAS module.

    Proxyless-specific gradient bookkeeping; the per-module semantics live in
    the supermodule implementation — TODO confirm against supermodule.proxyless.
    """
    for m in self.nas_modules:
        m.finalize_grad()  # type: ignore
class GumbelDartsLightningModule(DartsLightningModule):
......@@ -177,7 +181,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
def default_mutation_hooks(self) -> List[MutationHook]:
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with gumbel-differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
......@@ -195,7 +199,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
}
def __init__(self, inner_module,
mutation_hooks: List[MutationHook] = None,
mutation_hooks: list[MutationHook] | None = None,
arc_learning_rate: float = 3.0e-4,
gumbel_temperature: float = 1.,
use_temp_anneal: bool = False,
......@@ -206,12 +210,13 @@ class GumbelDartsLightningModule(DartsLightningModule):
self.use_temp_anneal = use_temp_anneal
self.min_temp = min_temp
def on_train_epoch_end(self):
    """Anneal the gumbel-softmax temperature at the end of each training epoch.

    Resolves diff residue: the old hook was named ``on_epoch_start`` and set
    ``module._softmax.temp`` unguarded; the new version runs at epoch end and
    guards modules that carry no ``_softmax``.
    """
    if self.use_temp_anneal:
        # linear schedule from init_temp towards min_temp over max_epochs,
        # clamped so the temperature never drops below min_temp
        self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
        self.temp = max(self.temp, self.min_temp)
        for module in self.nas_modules:
            # only gumbel-based modules expose a softmax with a temperature
            if hasattr(module, '_softmax'):
                module._softmax.temp = self.temp  # type: ignore
    return self.model.on_train_epoch_end()
......@@ -2,10 +2,14 @@
# Licensed under the MIT license.
import logging
import warnings
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler, DataLoader
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
......@@ -113,9 +117,9 @@ class ReinforceController(nn.Module):
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
def _lstm_next_step(self):
    # Advance the controller LSTM one step: feed the current inputs together
    # with the (hidden, cell) states and keep the new states for the next
    # sampling decision.
    self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
......@@ -143,7 +147,7 @@ class ReinforceController(nn.Module):
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) # type: ignore
sampled = sampled.detach().cpu().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
......@@ -205,6 +209,8 @@ class EnasTrainer(BaseOneShotTrainer):
batch_size=64, workers=4, device=None, log_frequency=None,
grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
warnings.warn('EnasTrainer is deprecated. Please use strategy.ENAS instead.', DeprecationWarning)
self.model = model
self.loss = loss
self.metrics = metrics
......@@ -246,16 +252,16 @@ class EnasTrainer(BaseOneShotTrainer):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
train_sampler = SubsetRandomSampler(indices[:-split])
valid_sampler = SubsetRandomSampler(indices[-split:])
self.train_loader = DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_model(self, epoch):
self.model.train()
......@@ -294,15 +300,15 @@ class EnasTrainer(BaseOneShotTrainer):
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
reward += self.entropy_weight * self.controller.sample_entropy.item() # type: ignore
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
metrics['reward'] = reward
metrics['loss'] = loss.item()
metrics['ent'] = self.controller.sample_entropy.item()
metrics['log_prob'] = self.controller.sample_log_prob.item()
metrics['ent'] = self.controller.sample_entropy.item() # type: ignore
metrics['log_prob'] = self.controller.sample_log_prob.item() # type: ignore
metrics['baseline'] = self.baseline
metrics['skip'] = self.controller.sample_skip_penalty
......
......@@ -4,6 +4,7 @@
# type: ignore
import logging
import warnings
import torch
import torch.nn as nn
......@@ -230,6 +231,8 @@ class ProxylessTrainer(BaseOneShotTrainer):
grad_reg_loss_type=None, grad_reg_loss_params=None,
applied_hardware=None, dummy_input=(1, 3, 224, 224),
ref_latency=65.0):
warnings.warn('ProxylessTrainer is deprecated. Please use strategy.Proxyless instead.', DeprecationWarning)
self.model = model
self.loss = loss
self.metrics = metrics
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# type: ignore
import logging
import random
import warnings
import torch
import torch.nn as nn
......@@ -122,6 +125,8 @@ class SinglePathTrainer(BaseOneShotTrainer):
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
batch_size=64, workers=4, device=None, log_frequency=None):
warnings.warn('SinglePathTrainer is deprecated. Please use strategy.RandomOneShot instead.', DeprecationWarning)
self.model = model
self.loss = loss
self.metrics = metrics
......
......@@ -3,7 +3,8 @@
"""Experimental version of sampling-based one-shot implementation."""
from typing import Dict, Any, List
from __future__ import annotations
from typing import Any
import pytorch_lightning as pl
import torch
......@@ -33,9 +34,11 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
)
# turn on automatic optimization because nothing interesting is going on here.
automatic_optimization = True
@property
def automatic_optimization(self) -> bool:
return True
def default_mutation_hooks(self) -> List[MutationHook]:
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
PathSamplingLayer.mutate,
......@@ -80,6 +83,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_grad_clip : float
Gradient clipping value of controller.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This will be not effective when there's only one metric returned from evaluator.
If there are multiple, it will find the metric with key name ``reward_metric_name``,
which is "default" by default.
Otherwise it raises an exception indicating multiple metrics are found.
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _enas_note.format(
......@@ -87,23 +96,26 @@ class EnasLightningModule(RandomSamplingLightningModule):
module_params=BaseOneShotLightningModule._inner_module_note,
)
automatic_optimization = False
@property
def automatic_optimization(self) -> bool:
return False
def __init__(self,
inner_module: pl.LightningModule,
*,
ctrl_kwargs: Dict[str, Any] = None,
ctrl_kwargs: dict[str, Any] | None = None,
entropy_weight: float = 1e-4,
skip_weight: float = .8,
baseline_decay: float = .999,
ctrl_steps_aggregate: float = 20,
ctrl_grad_clip: float = 0,
mutation_hooks: List[MutationHook] = None):
reward_metric_name: str | None = None,
mutation_hooks: list[MutationHook] | None = None):
super().__init__(inner_module, mutation_hooks)
# convert parameter spec to legacy ReinforceField
# this part will be refactored
self.nas_fields: List[ReinforceField] = []
self.nas_fields: list[ReinforceField] = []
for name, param_spec in self.search_space_spec().items():
if param_spec.chosen_size not in (1, None):
raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
......@@ -116,6 +128,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.ctrl_grad_clip = ctrl_grad_clip
self.reward_metric_name = reward_metric_name
def configure_architecture_optimizers(self):
    """Build the Adam optimizer (lr = 3.5e-4) over the RL controller's parameters."""
    controller_params = self.controller.parameters()
    return optim.Adam(controller_params, lr=3.5e-4)
......@@ -127,34 +140,35 @@ class EnasLightningModule(RandomSamplingLightningModule):
if source == 'train':
# step 1: train model params
self.resample()
self.call_user_optimizers('zero_grad')
self.call_weight_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(batch, batch_idx)
w_step_loss = loss_and_metrics['loss'] \
if isinstance(loss_and_metrics, dict) else loss_and_metrics
self.manual_backward(w_step_loss)
self.call_user_optimizers('step')
self.call_weight_optimizers('step')
return loss_and_metrics
if source == 'val':
# step 2: train ENAS agent
x, y = batch
arc_opt = self.architecture_optimizers
arc_opt = self.architecture_optimizers()
if not isinstance(arc_opt, optim.Optimizer):
raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
arc_opt.zero_grad()
self.resample()
with torch.no_grad():
logits = self.model(x)
self.model.validation_step(batch, batch_idx)
# use the default metric of self.model as reward function
if len(self.model.metrics) == 1:
_, metric = next(iter(self.model.metrics.items()))
if len(self.trainer.callback_metrics) == 1:
_, metric = next(iter(self.trainer.callback_metrics.items()))
else:
if 'default' not in self.model.metrics.keys():
raise KeyError('model.metrics should contain a ``default`` key when'
'there are multiple metrics')
metric = self.model.metrics['default']
metric_name = self.reward_metric_name or 'default'
if metric_name not in self.trainer.callback_metrics:
raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
metric = self.trainer.callback_metrics[metric_name]
reward: float = metric.item()
reward = metric(logits, y)
if self.entropy_weight:
reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
reward = reward + self.entropy_weight * self.controller.sample_entropy.item() # type: ignore
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
......@@ -183,7 +197,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
with torch.no_grad():
return self._interpret_controller_sampling_result(self.controller.resample())
def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
"""Convert ``{label: index}`` to ``{label: name}``"""
space_spec = self.search_space_spec()
for key in list(sample.keys()):
......
......@@ -10,8 +10,10 @@ For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to b
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""
from __future__ import annotations
import warnings
from typing import Any, List, Optional, Type, Union, Tuple
from typing import Any, Type
import torch.nn as nn
from torch.utils.data import DataLoader
......@@ -33,10 +35,10 @@ class OneShotStrategy(BaseStrategy):
self.oneshot_module = oneshot_module
self.oneshot_kwargs = kwargs
self.model: Optional[BaseOneShotLightningModule] = None
self.model: BaseOneShotLightningModule | None = None
def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
-> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-> DataLoader | tuple[DataLoader, DataLoader]:
"""
One-shot strategy typically requires a customized dataloader.
......@@ -51,9 +53,9 @@ class OneShotStrategy(BaseStrategy):
_reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'
py_model: nn.Module = base_model.python_object
if not isinstance(py_model, nn.Module):
if not isinstance(base_model.python_object, nn.Module):
raise TypeError('Model is not a nn.Module. ' + _reason)
py_model: nn.Module = base_model.python_object
if applied_mutators:
raise ValueError('Mutator is not empty. ' + _reason)
......@@ -64,8 +66,10 @@ class OneShotStrategy(BaseStrategy):
evaluator_module: LightningModule = base_model.evaluator.module
evaluator_module.set_model(py_model)
self.model: BaseOneShotLightningModule = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
evaluator: Lightning = base_model.evaluator
if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
raise TypeError('Train or val dataloader is not set.')
dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
if isinstance(dataloader, tuple):
dataloader, val_loader = dataloader
......@@ -73,7 +77,7 @@ class OneShotStrategy(BaseStrategy):
else:
evaluator.trainer.fit(self.model, dataloader)
def export_top_models(self, top_k: int = 1) -> List[Any]:
def export_top_models(self, top_k: int = 1) -> list[Any]:
if self.model is None:
raise RuntimeError('One-shot strategy needs to be run before export.')
if top_k != 1:
......
......@@ -26,8 +26,10 @@ The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
"""
from __future__ import annotations
import operator
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast
import numpy as np
import torch
......@@ -58,8 +60,8 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
for i in range(len(slice_)):
if isinstance(slice_[i], list):
# convert list of slices to mask
mask = np.zeros(shape[i], dtype=np.bool)
for sl in slice_[i]:
mask = np.zeros(shape[i], dtype=np.bool) # type: ignore
for sl in cast(List[slice], slice_[i]):
mask[sl] = 1
result.append(mask)
else:
......@@ -67,7 +69,7 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
return tuple(result)
def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_slice, float]]]) -> T:
def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
# slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
# or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}
......@@ -84,27 +86,27 @@ def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_s
# create a mask with weight w
with torch.no_grad():
mask = zeros_like(weight)
mask[_eliminate_list_slice(weight.shape, sl)] = 1
mask[_eliminate_list_slice(weight.shape, sl)] = 1 # type: ignore
# track gradients here
masks.append((mask * wt))
masks.append(mask * wt) # type: ignore
masks = sum(masks)
return masks * weight
return masks * weight # type: ignore
else:
# for unweighted case, we slice it directly.
def _do_slice(arr, slice_):
return arr[_eliminate_list_slice(arr.shape, slice_)]
return arr[_eliminate_list_slice(arr.shape, slice_)] # type: ignore
# sometimes, we don't need slice.
# this saves an op on computational graph, which will hopefully make training faster
# Use a dummy array to check this. Otherwise it would be too complex.
dummy_arr = np.zeros(weight.shape, dtype=np.bool)
no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape
dummy_arr = np.zeros(weight.shape, dtype=np.bool) # type: ignore
no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
if no_effect:
return weight
......@@ -128,14 +130,14 @@ class Slicable(Generic[T]):
raise TypeError(f'Unsuppoted weight type: {type(weight)}')
self.weight = weight
def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
def __getitem__(self, index: slice_type | multidim_slice) -> T:
if not isinstance(index, tuple):
index = (index, )
index = cast(multidim_slice, index)
# Get the dict value in index's leafs
# There can be at most one dict
leaf_dict: Optional[Dict[int, float]] = None
leaf_dict: dict[int, float] | None = None
for maybe_weighted in _iterate_over_multidim_slice(index):
for d in maybe_weighted.leaf_values():
if isinstance(d, dict):
......@@ -166,10 +168,10 @@ class MaybeWeighted:
"""
def __init__(self,
value: Optional[int_or_int_dict] = None, *,
lhs: Optional[Union['MaybeWeighted', int]] = None,
rhs: Optional[Union['MaybeWeighted', int]] = None,
operation: Optional[Callable[[int, int], int]] = None):
value: int_or_int_dict | None = None, *,
lhs: 'MaybeWeighted' | int | None = None,
rhs: 'MaybeWeighted' | int | None = None,
operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
if operation is None:
if not isinstance(value, (int, dict)):
raise TypeError(f'Unsupported value type: {type(value)}')
......@@ -178,7 +180,7 @@ class MaybeWeighted:
self.rhs = rhs
self.operation = operation
def leaf_values(self) -> Iterator[Dict[int, float]]:
def leaf_values(self) -> Iterator[int_or_int_dict]:
"""Iterate over values on leaf nodes."""
if self.value is not None:
yield self.value
......@@ -188,7 +190,7 @@ class MaybeWeighted:
if isinstance(self.rhs, MaybeWeighted):
yield from self.rhs.leaf_values()
def evaluate(self, value_fn: _value_fn_type = None) -> int:
def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
"""Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.
If ``value_fn`` is none, no replacement will happen and the raw value will be used.
"""
......@@ -200,11 +202,12 @@ class MaybeWeighted:
if isinstance(self.lhs, MaybeWeighted):
eval_lhs = self.lhs.evaluate(value_fn)
else:
eval_lhs = self.lhs
eval_lhs = cast(int, self.lhs)
if isinstance(self.rhs, MaybeWeighted):
eval_rhs = self.rhs.evaluate(value_fn)
else:
eval_rhs = self.rhs
eval_rhs = cast(int, self.rhs)
assert self.operation is not None
return self.operation(eval_lhs, eval_rhs)
def __repr__(self):
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. The complete support is an undergoing work item.
......@@ -96,9 +97,9 @@ class DifferentiableSuperConv2d(nn.Conv2d):
----------
input_weight : Tensor
the weight to be weighted summed
masks : List[Tensor]
masks : list[Tensor]
weight masks.
thresholds : List[float]
thresholds : list[float]
thresholds, should have a length of ``len(masks) - 1``
indicator : Callable[[Tensor, float], float]
take a tensor and a threshold as input, and output the weight
......
......@@ -4,19 +4,26 @@
"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""
from __future__ import annotations
import itertools
from typing import List, Any, Dict, Tuple, Optional, Union
from typing import Any, TypeVar, List, cast
import numpy as np
import torch
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX
Choice = Any
T = TypeVar('T')
__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']
def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
"""Find all leaf nodes in ``value_choices``,
save them in the format of ``{label: parameter_spec}``.
"""
......@@ -33,7 +40,7 @@ def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, Paramete
return result
def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
"""To evaluate a composition of value-choice with a dict,
with format of ``{label: chosen_value}``.
The implementation is two-pass. We first get a list of values,
......@@ -56,8 +63,10 @@ def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str
return value_choice.evaluate(choice_inner_values)
def traverse_all_options(value_choice: ValueChoiceX,
weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]:
def traverse_all_options(
value_choice: ChoiceOf[T],
weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
"""Traverse all possible computation outcome of a value choice.
If ``weights`` is not None, it will also compute the probability of each possible outcome.
......@@ -65,33 +74,33 @@ def traverse_all_options(value_choice: ValueChoiceX,
----------
value_choice : ValueChoiceX
The value choice to traverse.
weights : Optional[Dict[str, List[float]]], default = None
weights : Optional[dict[str, list[float]]], default = None
If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
weights can be provided. The key is the label; each value is a list of floats giving the probability of each option.
Normally, they should sum up to 1, but we will not check them in this function.
Returns
-------
List[Union[Tuple[Any, float], Any]]
list[Union[tuple[Any, float], Any]]
Results will be sorted and duplicates will be eliminated.
If weights is provided, the return value will be a list of tuple, with option and its weight.
Otherwise, it will be a list of options.
"""
# get a dict of {label: list of tuple of choice and weight}
leafs: Dict[str, List[Tuple[Choice, float]]] = {}
leafs: dict[str, list[tuple[T, float]]] = {}
for label, param_spec in dedup_inner_choices([value_choice]).items():
if weights is not None:
if label not in weights:
raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
if len(weights[label]) != param_spec.size:
raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
leafs[label] = list(zip(param_spec.values, weights[label]))
leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
else:
# create a dummy weight of zero, in case that weights are not provided.
leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))
# result is a dict from a option to its weight
result: Dict[str, Optional[float]] = {}
result: dict[T, float | None] = {}
labels, values = list(leafs.keys()), list(leafs.values())
if not labels:
......@@ -126,6 +135,6 @@ def traverse_all_options(value_choice: ValueChoiceX,
result[eval_res] = chosen_weight
if weights is None:
return sorted(result.keys())
return sorted(result.keys()) # type: ignore
else:
return sorted(result.items())
return sorted(result.items()) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, Tuple, Union
from __future__ import annotations
from typing import Any
import torch.nn as nn
......@@ -24,13 +26,13 @@ class BaseSuperNetModule(nn.Module):
rather than their compositions.
"""
def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
"""
Resample the super-net module.
Parameters
----------
memo : Dict[str, Any]
memo : dict[str, Any]
Used to ensure the consistency of samples with the same label.
Returns
......@@ -40,19 +42,19 @@ class BaseSuperNetModule(nn.Module):
"""
raise NotImplementedError()
def export(self, memo: dict[str, Any]) -> dict[str, Any]:
    """
    Export the final architecture within this module.
    It should have the same keys as ``search_space_spec()``.

    Parameters
    ----------
    memo : dict[str, Any]
        Use memo to avoid the same label gets exported multiple times.
    """
    # Abstract: concrete super-net modules must override this.
    raise NotImplementedError()
def search_space_spec(self) -> Dict[str, ParameterSpec]:
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""
Space specification (sample points).
Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.
......@@ -64,8 +66,8 @@ class BaseSuperNetModule(nn.Module):
raise NotImplementedError()
@classmethod
def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \
Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]:
def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
"""This is a mutation hook that creates a :class:`BaseSuperNetModule`.
The method should be implemented in each specific super-net module,
because they usually have specific rules about what kind of modules to operate on.
......@@ -84,7 +86,7 @@ class BaseSuperNetModule(nn.Module):
Returns
-------
Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]]
Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details.
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import functools
import warnings
from typing import List, Tuple, Optional, Dict, Any, Union
from typing import Any, cast
import torch
import torch.nn as nn
......@@ -21,7 +23,9 @@ from ._valuechoice_utils import traverse_all_options
class GumbelSoftmax(nn.Softmax):
"""Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""
# Softmax dimension; -1 (last axis) by default.
dim: int

def __init__(self, dim: int = -1) -> None:
    """Initialize the Gumbel-softmax wrapper.

    ``tau`` (temperature) and ``hard`` (straight-through flag) are the usual
    ``F.gumbel_softmax`` knobs; they default to 1 and False and may be
    overwritten by callers.
    """
    super().__init__(dim)
    self.tau = 1
    self.hard = False
......@@ -42,7 +46,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
Parameters
----------
paths : List[Tuple[str, nn.Module]]
paths : list[tuple[str, nn.Module]]
Layers to choose from. Each is a tuple of name, and its module.
alpha : Tensor
Tensor that stores the "learnable" weights.
......@@ -59,9 +63,9 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
Name of the choice.
"""
_arch_parameter_names: List[str] = ['_arch_alpha']
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__()
self.op_names = []
if len(alpha) != len(paths):
......@@ -82,7 +86,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
"""Choose the operator with the maximum logit."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
......@@ -149,9 +153,9 @@ class DifferentiableMixedInput(BaseSuperNetModule):
Name of the choice.
"""
_arch_parameter_names: List[str] = ['_arch_alpha']
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__()
self.n_candidates = n_candidates
if len(alpha) != n_candidates:
......@@ -240,9 +244,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
won't be optimized.
"""
_arch_parameter_names: List[str] = ['_arch_alpha']
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
operation._arch_alpha = nn.ParameterDict()
for name, spec in operation.search_space_spec().items():
......@@ -254,20 +258,20 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
operation._arch_alpha[name] = alpha
operation.parameters = functools.partial(self.parameters, self=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, self=operation)
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, module=operation)
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@staticmethod
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
def parameters(module, *args, **kwargs):
for _, p in module.named_parameters(*args, **kwargs):
yield p
@staticmethod
def named_parameters(self, *args, **kwargs):
def named_parameters(module, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super(self.__class__, self).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
......@@ -275,22 +279,24 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
if not arch:
yield name, p
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
    """Differentiable policies optimize architecture weights by gradient,
    so there is nothing to (re)sample; always return an empty dict."""
    return {}
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
    """Export the argmax choice of the architecture weights for each leaf value choice.

    Note: unlike the sampling policy, this is deterministic (argmax of
    ``_arch_alpha``), not random. ``memo`` is accepted for interface
    compatibility but is not read here.
    """
    result = {}
    for name, spec in operation.search_space_spec().items():
        if name in result:
            continue
        # int(...) narrows Tensor.item()'s Number return for indexing.
        chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
        result[name] = spec.values[chosen_index]
    return result
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
    """Resolve argument ``name`` for the forward pass.

    For a mutable argument, expand it to ``{option: probability}`` by pushing
    the architecture weights through the softmax and traversing all options;
    for a fixed argument, return the init-time value unchanged.
    """
    if name in operation.mutable_arguments:
        # softmax turns per-label alpha logits into option probabilities
        weights: dict[str, torch.Tensor] = {
            label: cast(nn.Module, operation._softmax)(alpha)
            for label, alpha in cast(dict, operation._arch_alpha).items()
        }
        return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
    return operation.init_arguments[name]
......@@ -6,9 +6,11 @@ Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""
from __future__ import annotations
import inspect
import itertools
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast
from typing import Any, Type, TypeVar, cast, Union, Tuple
import torch
import torch.nn as nn
......@@ -37,7 +39,7 @@ class MixedOperationSamplingPolicy:
One SamplingStrategy corresponds to one mixed operation.
"""
def __init__(self, operation: 'MixedOperation', memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
"""At init, the sampling policy can prepare basic parameters,
and store them in operation if they need back propagation.
......@@ -47,11 +49,11 @@ class MixedOperationSamplingPolicy:
"""
pass
def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
    """The handler of :meth:`MixedOperation.resample`. Abstract; subclasses implement."""
    raise NotImplementedError()
def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
    """The handler of :meth:`MixedOperation.export`. Abstract; subclasses implement."""
    raise NotImplementedError()
......@@ -90,7 +92,7 @@ class MixedOperation(BaseSuperNetModule):
"""
bound_type: Type[nn.Module] # defined in subclass
argument_list: List[str] # defined in subclass
argument_list: list[str] # defined in subclass
sampling_policy: MixedOperationSamplingPolicy
......@@ -114,11 +116,11 @@ class MixedOperation(BaseSuperNetModule):
appended by forward arguments in the ``bound_type``."""
raise NotImplementedError()
def __init__(self, module_kwargs: Dict[str, Any]) -> None:
def __init__(self, module_kwargs: dict[str, Any]) -> None:
# Concerned arguments
self.mutable_arguments: Dict[str, ValueChoiceX] = {}
self.mutable_arguments: dict[str, ValueChoiceX] = {}
# Useful when retrieving arguments without ValueChoice
self.init_arguments: Dict[str, Any] = {**module_kwargs}
self.init_arguments: dict[str, Any] = {**module_kwargs}
self._fill_missing_init_arguments()
# get init default
......@@ -134,7 +136,7 @@ class MixedOperation(BaseSuperNetModule):
super_init_kwargs[key] = value
# get all inner leaf value choices
self._space_spec: Dict[str, ParameterSpec] = dedup_inner_choices(self.mutable_arguments.values())
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
super().__init__(**super_init_kwargs)
......@@ -156,17 +158,17 @@ class MixedOperation(BaseSuperNetModule):
"""Find value choice in module's arguments and replace the whole module"""
has_valuechoice = False
if isinstance(module, cls.bound_type) and is_traceable(module):
for arg in itertools.chain(module.trace_args, module.trace_kwargs.values()):
for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
if has_valuechoice:
if module.trace_args:
raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
'Please enable ``kw_only`` on nni.trace.')
'Please enable ``kw_only`` on nni.trace.')
# save type and kwargs
mixed_op = cls(module.trace_kwargs)
mixed_op = cls(cast(dict, module.trace_kwargs))
if 'mixed_op_sampling' not in mutate_kwargs:
raise ValueError('Need to sampling policy of mixed op, but not found in `mutate_kwargs`.')
......@@ -229,15 +231,15 @@ class MixedLinear(MixedOperation, nn.Linear):
out_features: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor:
in_features = _W(in_features)
out_features = _W(out_features)
in_features_ = _W(in_features)
out_features_ = _W(out_features)
weight = _S(self.weight)[:out_features]
weight = _S(weight)[:, :in_features]
weight = _S(self.weight)[:out_features_]
weight = _S(weight)[:, :in_features_]
if self.bias is None:
bias = self.bias
else:
bias = _S(self.bias)[:out_features]
bias = _S(self.bias)[:out_features_]
return F.linear(inputs, weight, bias)
......@@ -278,7 +280,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
]
@staticmethod
def _to_tuple(value: scalar_or_scalar_dict[T]) -> Tuple[T, T]:
def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
if not isinstance(value, tuple):
return (value, value)
return value
......@@ -318,33 +320,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
raise ValueError('stride, dilation, groups does not support weighted sampling.')
in_channels = _W(in_channels)
out_channels = _W(out_channels)
in_channels_ = _W(in_channels)
out_channels_ = _W(out_channels)
# slice prefix
# For groups > 1, we use groups to slice input weights
weight = _S(self.weight)[:out_channels]
weight = _S(weight)[:, :in_channels // groups]
weight = _S(self.weight)[:out_channels_]
weight = _S(weight)[:, :in_channels_ // groups]
# slice center
if isinstance(kernel_size, dict):
# If kernel size is a dict, ignore choices in padding.
if isinstance(self.padding, str):
raise ValueError(f'Use "{self.padding}" in padding is not supported.')
padding = self.padding # max padding, must be a tuple
kernel_a, kernel_b = self._to_tuple(kernel_size)
kernel_a, kernel_b = _W(kernel_a), _W(kernel_b)
kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
max_kernel_a, max_kernel_b = self.kernel_size # self.kernel_size must be a tuple
kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a) // 2, (max_kernel_b - kernel_b) // 2
weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a, kernel_b_top:kernel_b_top + kernel_b]
kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
bias = _S(self.bias)[:out_channels] if self.bias is not None else None
bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
# The rest parameters only need to be converted to tuple
stride = self._to_tuple(stride)
dilation = self._to_tuple(dilation)
stride_ = self._to_tuple(stride)
dilation_ = self._to_tuple(dilation)
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, bias, stride, (0, 0), dilation, groups)
return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
weight, bias, stride_, (0, 0), dilation_, groups)
return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)
class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
......@@ -388,13 +394,15 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
if num_features < self.num_features:
weight = weight[:num_features]
bias = bias[:num_features]
running_mean = running_mean[:num_features]
running_var = running_var[:num_features]
if running_mean is not None:
running_mean = running_mean[:num_features]
if running_var is not None:
running_var = running_var[:num_features]
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
bn_training = (running_mean is None) and (running_var is None)
return F.batch_norm(
inputs,
......@@ -473,7 +481,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def _to_proj_slice(self, embed_dim: _W) -> List[slice]:
def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
# slice three parts, corresponding to q, k, v respectively
return [
slice(embed_dim),
......@@ -484,12 +492,12 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
def forward_with_args(
self,
embed_dim: int_or_int_dict, num_heads: int,
kdim: Optional[int_or_int_dict], vdim: Optional[int_or_int_dict],
kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
dropout: float,
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
key_padding_mask: torch.Tensor | None = None,
need_weights: bool = True, attn_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
raise ValueError('num_heads, dropout do not support weighted sampling.')
......@@ -511,26 +519,26 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
else:
used_embed_dim = embed_dim
embed_dim = _W(embed_dim)
embed_dim_ = _W(embed_dim)
# in projection weights & biases has q, k, v weights concatenated together
in_proj_bias: Optional[Tensor] = None
in_proj_weight: Optional[Tensor] = None
in_proj_bias: Tensor | None = None
in_proj_weight: Tensor | None = None
if self.in_proj_bias is not None:
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
if self.in_proj_weight is not None:
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None
if not qkv_same_embed_dim:
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
k_proj = _S(k_proj)[:, :_W(kdim)]
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
v_proj = _S(v_proj)[:, :_W(vdim)]
# The rest part is basically same as pytorch
......@@ -560,7 +568,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
return attn_output, attn_output_weights
NATIVE_MIXED_OPERATIONS: List[Type[MixedOperation]] = [
NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
MixedLinear,
MixedConv2d,
MixedBatchNorm2d,
......
......@@ -9,7 +9,9 @@ The support remains limited. Known limitations include:
- The code contains duplicates. Needs refactor.
"""
from typing import List, Tuple, Optional, cast
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
......@@ -48,13 +50,13 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
    """Proxyless mixed layer: DARTS-style layer plus near-zero-initialized binary gates.

    ``_sampled`` / ``_sample_idx`` record the currently active path (name and index).
    """
    super().__init__(paths, alpha, softmax, label)
    self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
    # like sampling-based methods, it has a ``_sampled``.
    self._sampled: str | None = None
    self._sample_idx: int | None = None
def forward(self, *args, **kwargs):
def run_function(ops, active_id, **kwargs):
......@@ -130,10 +132,10 @@ class ProxylessMixedInput(DifferentiableMixedInput):
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
    """Proxyless input choice: adds near-zero-initialized binary gates over candidates.

    ``_sampled`` records the index of the currently active input.
    """
    super().__init__(n_candidates, n_chosen, alpha, softmax, label)
    self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
    self._sampled: int | None = None
def forward(self, inputs):
def run_function(active_sample):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import random
from typing import Optional, List, Tuple, Union, Dict, Any
from typing import Any
import torch
import torch.nn as nn
......@@ -28,14 +30,14 @@ class PathSamplingLayer(BaseSuperNetModule):
Name of the choice.
"""
def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
    """Register every candidate path as a named submodule and remember its name."""
    super().__init__()
    self.op_names = []
    for name, module in paths:
        self.add_module(name, module)
        self.op_names.append(name)
    assert self.op_names, 'There has to be at least one op to choose from.'
    self._sampled: list[str] | str | None = None  # sampled can be either a list of indices or an index
    self.label = label
def resample(self, memo):
......@@ -89,7 +91,7 @@ class PathSamplingInput(BaseSuperNetModule):
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction = reduction
self._sampled: Optional[Union[List[int], int]] = None
self._sampled: list[int] | int | None = None
self.label = label
def _random_choose_n(self):
......@@ -159,11 +161,11 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
We sample the leaf nodes, and compose them into the values of the arguments.
"""
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
    """Prepare the per-leaf sample cache; it is filled by ``resample``."""
    # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
    self._sampled: dict[str, Any] | None = None
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Random sample for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
......@@ -181,7 +183,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from collections import OrderedDict
from typing import cast
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
......@@ -155,7 +159,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
Returns
-------
List[Tuple[str, nn.Module]]
list[tuple[str, nn.Module]]
A list of tuples, each containing a layer choice key (name) and the replaced module.
"""
return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
......@@ -176,7 +180,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
Returns
-------
List[Tuple[str, nn.Module]]
list[tuple[str, nn.Module]]
A list of tuples, each containing a layer choice key (name) and the replaced module.
"""
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
......@@ -200,15 +204,19 @@ class InterleavedTrainValDataLoader(DataLoader):
Example
--------
Fit your dataloaders into a parallel one.
>>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)
Then you can use the ``para_loader`` as a normal training loader.
"""
def __init__(self, train_dataloader, val_dataloader):
self.train_loader = train_dataloader
self.val_loader = val_dataloader
def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
    """Wrap a train loader and a validation loader so batches interleave.

    Parameters
    ----------
    train_dataloader
        The training dataloader.
    val_dataloader
        A single validation dataloader. A list of loaders is rejected.

    Raises
    ------
    TypeError
        If ``val_dataloader`` is a list of dataloaders.
    """
    # Interleaving is only defined over a single validation loader.
    if isinstance(val_dataloader, list):
        raise TypeError('Validation dataloader of type list is not supported.')
    self.train_loader: DataLoader = train_dataloader
    self.val_loader: DataLoader = val_dataloader
    # Compare lengths once up front; iteration uses these flags to decide
    # which loader is exhausted first without recomputing len() per step.
    n_train, n_val = len(train_dataloader), len(val_dataloader)
    self.equal_len = n_train == n_val
    self.train_longer = n_train > n_val
    # The parent DataLoader requires a dataset argument; this wrapper never
    # reads it, so a None placeholder is passed (cast silences the checker).
    super().__init__(cast(Dataset, None))
def __iter__(self):
self.train_iter = iter(self.train_loader)
......@@ -268,13 +276,17 @@ class ConcatenateTrainValDataLoader(DataLoader):
Example
--------
Fit your dataloaders into a concatenated one.
>>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloader)
Then you can use the ``concat_loader`` as a normal training loader.
"""
def __init__(self, train_dataloader, val_dataloader):
self.train_loader = train_dataloader
self.val_loader = val_dataloader
super().__init__(None)
def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
    """Chain a train loader and a validation loader into one loader.

    Parameters
    ----------
    train_dataloader
        The training dataloader, iterated first.
    val_dataloader
        A single validation dataloader, iterated after training batches.

    Raises
    ------
    TypeError
        If ``val_dataloader`` is a list of dataloaders.
    """
    # Concatenation only supports a single validation loader.
    if isinstance(val_dataloader, list):
        raise TypeError('Validation dataloader of type list is not supported.')
    self.train_loader: DataLoader = train_dataloader
    self.val_loader: DataLoader = val_dataloader
    # Parent class insists on a dataset argument; all iteration is
    # delegated to the wrapped loaders, so a None placeholder suffices.
    super().__init__(cast(Dataset, None))
def __iter__(self):
self.cur_iter = iter(self.train_loader)
......
......@@ -14,7 +14,6 @@
"nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo",
"nni/retiarii/oneshot",
"nni/smartparam.py",
"nni/tools/annotation",
"nni/tools/gpu_tool",
......
......@@ -130,13 +130,16 @@ def test_fit_api():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.1, # for faster training
progress_bar_refresh_rate=progress_bar_refresh_rate)
lightning.fit(lambda: MNISTModel())
lightning.fit(MNISTModel)
lightning.fit(MNISTModel())
def lightning(): return pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.1, # for faster training
progress_bar_refresh_rate=progress_bar_refresh_rate)
# Lightning will have some cache in models / trainers,
# which is problematic if we call fit multiple times.
lightning().fit(lambda: MNISTModel())
lightning().fit(MNISTModel)
lightning().fit(MNISTModel())
_reset()
......
......@@ -12,6 +12,7 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.strategy import BaseStrategy
class DepthwiseSeparableConv(nn.Module):
......@@ -237,8 +238,12 @@ def _test_strategy(strategy_, support_value_choice=True):
]
for (base_model, evaluator), support_or_not in to_test:
print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
if isinstance(strategy_, BaseStrategy):
strategy = strategy_
else:
strategy = strategy_(base_model, evaluator)
print('Testing:', type(strategy).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy)
config = RetiariiExeConfig()
config.execution_engine = 'oneshot'
......@@ -263,7 +268,12 @@ def test_proxyless():
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
    """Exercise the ENAS strategy over the test model zoo.

    The reward metric differs per model — the attention model is a
    regression task reporting ``val_mse``, the rest report ``val_acc`` —
    so a factory builds a correctly-configured ENAS per base model.
    """
    def strategy_fn(base_model, evaluator):
        # NOTE(review): MultiHeadAttentionNet appears to be the regression
        # model, hence the MSE reward — confirm against its evaluator.
        metric = 'val_mse' if isinstance(base_model, MultiHeadAttentionNet) else 'val_acc'
        return strategy.ENAS(reward_metric_name=metric)

    _test_strategy(strategy_fn)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment