"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "653aa75e02068118579ebc3179c6f834919eb1e5"
Unverified commit 05c7d6e9, authored by Yuge Zhang, committed by GitHub

Typehint for oneshot NAS (#4811)

parent cbac2c5c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+from __future__ import annotations

import warnings
from itertools import chain
-from typing import Dict, Callable, List, Union, Any, Tuple
+from typing import Callable, Any, Dict, Union, Tuple, List, cast

import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
+from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

@@ -24,8 +27,8 @@ MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[

def traverse_and_mutate_submodules(
-    root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
-) -> List[BaseSuperNetModule]:
+    root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
+) -> list[BaseSuperNetModule]:
    """
    Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.

@@ -36,7 +39,7 @@ def traverse_and_mutate_submodules(
        Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
        it's usually a ``pytorch_lightning.LightningModule``.
        The mutation will be in-place on ``root_module``.
-    hooks : List[MutationHook]
+    hooks : list[MutationHook]
        List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
        When a hook returns an module, the module will be replaced (mutated) to the new module.
    mutate_kwargs : dict

@@ -47,7 +50,7 @@ def traverse_and_mutate_submodules(
    Returns
    ----------
-    modules : Dict[str, nn.Module]
+    modules : dict[str, nn.Module]
        The replace result.
    """
    memo = {}

@@ -101,7 +104,7 @@ def traverse_and_mutate_submodules(
    return module_list
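To make the hook contract above concrete, here is a minimal, hypothetical hook (not part of this commit): it receives ``(module, name, memo, mutate_kwargs)`` and returns a replacement module, or ``None`` to leave the submodule untouched. ``MySuperLinear`` is an assumed user-defined :class:`BaseSuperNetModule`.

def replace_linear_hook(module: nn.Module, name: str, memo: dict, mutate_kwargs: dict):
    # Hypothetical hook matching the MutationHook signature.
    if isinstance(module, nn.Linear):
        # Returning a module tells traverse_and_mutate_submodules to swap it in-place.
        return MySuperLinear.mutate(module, name, memo, mutate_kwargs)
    return None  # fall through to the next hook (eventually no_default_hook)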
-def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
+def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
    """Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""

    # Forward IS NOT supernet

@@ -125,7 +128,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
    if is_traceable(module):
        # check whether there is a value-choice in its arguments
        has_valuechoice = False
-        for arg in chain(module.trace_args, module.trace_kwargs.values()):
+        for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
            if isinstance(arg, ValueChoiceX):
                has_valuechoice = True
                break

@@ -139,7 +142,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k

class BaseOneShotLightningModule(pl.LightningModule):

-    _mutation_hooks_note = """mutation_hooks : List[MutationHook]
+    _mutation_hooks_note = """mutation_hooks : list[MutationHook]
        Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
        They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
        For each submodule, the hook list are invoked subsequently,

@@ -194,36 +197,40 @@ class BaseOneShotLightningModule(pl.LightningModule):
    Attributes
    ----------
-    nas_modules : List[BaseSuperNetModule]
+    nas_modules : list[BaseSuperNetModule]
        Modules that have been mutated, which the search algorithms should care about.

    Parameters
    ----------
    """ + _inner_module_note + _mutation_hooks_note

-    automatic_optimization = False
+    trainer: pl.Trainer
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
        """Override this to define class-default mutation hooks."""
        return [no_default_hook]

-    def mutate_kwargs(self) -> Dict[str, Any]:
+    def mutate_kwargs(self) -> dict[str, Any]:
        """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
        return {}

-    def __init__(self, base_model: pl.LightningModule, mutation_hooks: List[MutationHook] = None):
+    def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
        super().__init__()
-        assert isinstance(base_model, pl.LightningModule)
-        self.model = base_model
+        assert isinstance(model, pl.LightningModule)
+        self.model = model

        # append the default hooks
        mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

        # traverse the model, calling hooks on every submodule
-        self.nas_modules: List[BaseSuperNetModule] = traverse_and_mutate_submodules(
+        self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
            self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)

-    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """Get the search space specification from ``nas_module``.

        Returns

@@ -236,7 +243,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
            result.update(module.search_space_spec())
        return result

-    def resample(self) -> Dict[str, Any]:
+    def resample(self) -> dict[str, Any]:
        """Trigger the resample for each ``nas_module``.

        Sometimes (e.g., in differentiable cases), it does nothing.

@@ -250,7 +257,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
            result.update(module.resample(memo=result))
        return result

-    def export(self) -> Dict[str, Any]:
+    def export(self) -> dict[str, Any]:
        """
        Export the NAS result, ideally the best choice of each ``nas_module``.
        You may implement an ``export`` method for your customized ``nas_module``.

@@ -291,12 +298,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

+        # FIXME: this part uses non-official lightning API.
        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
-        w_optimizers, lr_schedulers, self.frequencies, monitor = \
-            self.trainer._configure_optimizers(self.model.configure_optimizers())
-        lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
+        try:
+            # above v1.6
+            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
+                _configure_optimizers,  # type: ignore
+                _configure_schedulers_automatic_opt,  # type: ignore
+                _configure_schedulers_manual_opt  # type: ignore
+            )
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = (
+                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
+                if self.automatic_optimization
+                else _configure_schedulers_manual_opt(lr_schedulers)
+            )
+        except ImportError:
+            # under v1.5
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
@@ -312,7 +337,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
        # redirect the access to trainer/log to this module
        # but note that we might be missing other attributes,
        # which could potentially be a problem
-        self.model.trainer = self.trainer
+        self.model.trainer = self.trainer  # type: ignore
        self.model.log = self.log
        return self.model.on_train_start()

@@ -359,7 +384,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
        Returns
        ----------
-        arc_optimizers : List[Optimizer], Optimizer
+        arc_optimizers : list[Optimizer], Optimizer
            Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

@@ -376,9 +401,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
        """
        def apply(lr_scheduler):
            # single scheduler is called every epoch
-            if isinstance(lr_scheduler, _LRScheduler) and \
-                    self.trainer.is_last_batch:
-                lr_schedulers.step()
+            if isinstance(lr_scheduler, _LRScheduler):
+                if self.trainer.is_last_batch:
+                    lr_scheduler.step()
            # lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']

@@ -392,7 +417,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
                    self.trainer.is_last_batch and
                    (self.trainer.current_epoch + 1) % frequency == 0
                ):
-                    lr_scheduler.step()
+                    lr_scheduler['scheduler'].step()

        lr_schedulers = self.lr_schedulers()

@@ -402,7 +427,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
        else:
            apply(lr_schedulers)
-    def call_user_optimizers(self, method):
+    def call_weight_optimizers(self, method):
        """
        Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
        class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

@@ -418,10 +443,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
            elif method == 'zero_grad':
                optimizer.zero_grad()

-        optimizers = self.user_optimizers
+        optimizers = self.weight_optimizers()
        if optimizers is None:
            return
+        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

        if len(self.frequencies) > 0:
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:

@@ -434,14 +461,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
            for optimizer in optimizers:
                apply_method(optimizer, method)

-    @property
-    def architecture_optimizers(self):
+    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.

        Returns
        ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
            Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there is no
            architecture optimizers.
        """

@@ -450,28 +476,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
            # pylint: disable=unsubscriptable-object
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
-                arc_opts = arc_opts[0]
-            return arc_opts
+                return cast(Optimizer, arc_opts[0])
+            return cast(List[Optimizer], arc_opts)

        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
        return None

-    @property
-    def user_optimizers(self):
+    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

        Returns
        ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
            Optimizers defined by user's model. This will be None if there is no user optimizers.
        """
+        # Since use_pl_optimizer is set true (by default) here.
+        # opts always return a list
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
-            return opts[self.arc_optim_count:]
+            return cast(List[Optimizer], opts[self.arc_optim_count:])
+        # FIXME: this case is actually not correctly handled
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
        return None
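Putting the helpers above together, a rough sketch of how a subclass typically drives them from a manual-optimization ``training_step`` (illustrative only; it mirrors the DARTS module further down and assumes a single architecture optimizer and a dict-returning inner ``training_step``):

class MyOneShotModule(BaseOneShotLightningModule):
    def training_step(self, batch, batch_idx):
        # architecture step, driven by the optimizer from configure_architecture_optimizers()
        arc_opt = self.architecture_optimizers()
        arc_opt.zero_grad()
        arc_loss = self.model.training_step(batch, batch_idx)['loss']
        self.manual_backward(arc_loss)
        arc_opt.step()

        # weight step, delegated to the user's optimizers and schedulers
        self.resample()
        self.call_weight_optimizers('zero_grad')
        loss = self.model.training_step(batch, batch_idx)['loss']
        self.manual_backward(loss)
        self.call_weight_optimizers('step')
        self.call_lr_schedulers(batch_idx)
        return loss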
@@ -5,6 +5,7 @@

import copy
import logging
+import warnings
from collections import OrderedDict

import torch

@@ -111,6 +112,8 @@ class DartsTrainer(BaseOneShotTrainer):
                 learning_rate=2.5E-3, batch_size=64, workers=4,
                 device=None, log_frequency=None,
                 arc_learning_rate=3.0E-4, unrolled=False):
+        warnings.warn('DartsTrainer is deprecated. Please use strategy.DARTS instead.', DeprecationWarning)

        self.model = model
        self.loss = loss
        self.metrics = metrics
......
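The deprecation warning above points users to the strategy-based API. A rough migration sketch (hedged: the experiment and config classes reflect the surrounding NNI API as I understand it, and are not shown in this commit):

from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

search_strategy = strategy.DARTS()   # replaces DartsTrainer(model, loss, metrics, ...)
experiment = RetiariiExperiment(base_model, evaluator, [], search_strategy)
config = RetiariiExeConfig()
config.execution_engine = 'oneshot'  # one-shot strategies need the oneshot execution engine
experiment.run(config)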
@@ -3,9 +3,11 @@

"""Experimental version of differentiable one-shot implementation."""

-from typing import List
+from __future__ import annotations

import pytorch_lightning as pl
import torch
+import torch.optim as optim

from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (

@@ -45,7 +47,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with differentiable versions"""
        hooks = [
            DifferentiableMixedLayer.mutate,

@@ -62,14 +64,16 @@ class DartsLightningModule(BaseOneShotLightningModule):
    }

    def __init__(self, inner_module: pl.LightningModule,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0E-4):
        self.arc_learning_rate = arc_learning_rate
        super().__init__(inner_module, mutation_hooks=mutation_hooks)

    def training_step(self, batch, batch_idx):
        # grad manually
-        arc_optim = self.architecture_optimizers
+        arc_optim = self.architecture_optimizers()
+        if not isinstance(arc_optim, optim.Optimizer):
+            raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

        # The InterleavedTrainValDataLoader yields both train and val data in a batch
        trn_batch, val_batch = batch

@@ -88,12 +92,12 @@ class DartsLightningModule(BaseOneShotLightningModule):
        # phase 2: model step
        self.resample()
-        self.call_user_optimizers('zero_grad')
+        self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
-        self.call_user_optimizers('step')
+        self.call_weight_optimizers('step')

        self.call_lr_schedulers(batch_idx)

@@ -107,7 +111,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
-            ctrl_params += list(m.parameters(arch=True))
+            ctrl_params += list(m.parameters(arch=True))  # type: ignore
        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
                                      weight_decay=1.0E-3)
        return ctrl_optim
@@ -135,7 +139,7 @@ class ProxylessLightningModule(DartsLightningModule):
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with gumbel-differentiable versions"""
        hooks = [
            ProxylessMixedLayer.mutate,

@@ -147,7 +151,7 @@ class ProxylessLightningModule(DartsLightningModule):
    def finalize_grad(self):
        for m in self.nas_modules:
-            m.finalize_grad()
+            m.finalize_grad()  # type: ignore


class GumbelDartsLightningModule(DartsLightningModule):

@@ -177,7 +181,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with gumbel-differentiable versions"""
        hooks = [
            DifferentiableMixedLayer.mutate,

@@ -195,7 +199,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
    }

    def __init__(self, inner_module,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,

@@ -206,12 +210,13 @@ class GumbelDartsLightningModule(DartsLightningModule):
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

-    def on_epoch_start(self):
+    def on_train_epoch_end(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)
        for module in self.nas_modules:
-            module._softmax.temp = self.temp
-        return self.model.on_epoch_start()
+            if hasattr(module, '_softmax'):
+                module._softmax.temp = self.temp  # type: ignore
+        return self.model.on_train_epoch_end()
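For intuition on the annealing schedule above: with ``init_temp = 1.0``, ``min_temp = 0.33`` and ``max_epochs = 50`` (illustrative numbers, not defaults from this commit), epoch 25 gives ``temp = (1 - 25/50) * (1.0 - 0.33) + 0.33 = 0.665``; the linear term reaches ``min_temp`` exactly at the final epoch, and the ``max(...)`` clamp only guarantees the temperature never drops below it.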
@@ -2,10 +2,14 @@

# Licensed under the MIT license.

import logging
+import warnings
+from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
+from torch.utils.data import SubsetRandomSampler, DataLoader

from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice

@@ -113,9 +117,9 @@ class ReinforceController(nn.Module):
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
-        self.sample_log_prob = 0
-        self.sample_entropy = 0
-        self.sample_skip_penalty = 0
+        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

@@ -143,7 +147,7 @@ class ReinforceController(nn.Module):
        if sampled.sum().item():
            self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
        else:
-            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
+            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

        sampled = sampled.detach().cpu().numpy().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)

@@ -205,6 +209,8 @@ class EnasTrainer(BaseOneShotTrainer):
                 batch_size=64, workers=4, device=None, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
+        warnings.warn('EnasTrainer is deprecated. Please use strategy.ENAS instead.', DeprecationWarning)

        self.model = model
        self.loss = loss
        self.metrics = metrics

@@ -246,16 +252,16 @@ class EnasTrainer(BaseOneShotTrainer):
        n_train = len(self.dataset)
        split = n_train // 2
        indices = list(range(n_train))
-        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
-        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
-        self.train_loader = torch.utils.data.DataLoader(self.dataset,
+        train_sampler = SubsetRandomSampler(indices[:-split])
+        valid_sampler = SubsetRandomSampler(indices[-split:])
+        self.train_loader = DataLoader(self.dataset,
                                       batch_size=self.batch_size,
                                       sampler=train_sampler,
                                       num_workers=self.workers)
-        self.valid_loader = torch.utils.data.DataLoader(self.dataset,
+        self.valid_loader = DataLoader(self.dataset,
                                       batch_size=self.batch_size,
                                       sampler=valid_sampler,
                                       num_workers=self.workers)

    def _train_model(self, epoch):
        self.model.train()

@@ -294,15 +300,15 @@ class EnasTrainer(BaseOneShotTrainer):
            metrics = self.metrics(logits, y)
            reward = self.reward_function(logits, y)
            if self.entropy_weight:
-                reward += self.entropy_weight * self.controller.sample_entropy.item()
+                reward += self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.controller.sample_skip_penalty
            metrics['reward'] = reward
            metrics['loss'] = loss.item()
-            metrics['ent'] = self.controller.sample_entropy.item()
-            metrics['log_prob'] = self.controller.sample_log_prob.item()
+            metrics['ent'] = self.controller.sample_entropy.item()  # type: ignore
+            metrics['log_prob'] = self.controller.sample_log_prob.item()  # type: ignore
            metrics['baseline'] = self.baseline
            metrics['skip'] = self.controller.sample_skip_penalty
......
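As a quick sanity check on the moving-average baseline shown above: with ``baseline_decay = 0.999`` and a constant reward of 0.9 (illustrative numbers), the baseline climbs from 0 toward 0.9 at 0.1% of the remaining gap per controller step, reaching roughly 0.45 after about 693 steps (half-life ≈ ln 2 / 0.001), so the advantage ``reward - baseline`` that scales ``sample_log_prob`` shrinks as training stabilizes.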
@@ -4,6 +4,7 @@

# type: ignore

import logging
+import warnings

import torch
import torch.nn as nn

@@ -230,6 +231,8 @@ class ProxylessTrainer(BaseOneShotTrainer):
                 grad_reg_loss_type=None, grad_reg_loss_params=None,
                 applied_hardware=None, dummy_input=(1, 3, 224, 224),
                 ref_latency=65.0):
+        warnings.warn('ProxylessTrainer is deprecated. Please use strategy.Proxyless instead.', DeprecationWarning)

        self.model = model
        self.loss = loss
        self.metrics = metrics
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+# type: ignore

import logging
import random
+import warnings

import torch
import torch.nn as nn

@@ -122,6 +125,8 @@ class SinglePathTrainer(BaseOneShotTrainer):
    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 batch_size=64, workers=4, device=None, log_frequency=None):
+        warnings.warn('SinglePathTrainer is deprecated. Please use strategy.RandomOneShot instead.', DeprecationWarning)

        self.model = model
        self.loss = loss
        self.metrics = metrics
......
@@ -3,7 +3,8 @@

"""Experimental version of sampling-based one-shot implementation."""

-from typing import Dict, Any, List
+from __future__ import annotations
+from typing import Any

import pytorch_lightning as pl
import torch

@@ -33,9 +34,11 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
    )

    # turn on automatic optimization because nothing interesting is going on here.
-    automatic_optimization = True
+    @property
+    def automatic_optimization(self) -> bool:
+        return True

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with differentiable versions"""
        hooks = [
            PathSamplingLayer.mutate,

@@ -80,6 +83,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
        Number of steps that will be aggregated into one mini-batch for RL controller.
    ctrl_grad_clip : float
        Gradient clipping value of controller.
+    reward_metric_name : str or None
+        The name of the metric which is treated as reward.
+        This will be not effective when there's only one metric returned from evaluator.
+        If there are multiple, it will find the metric with key name ``reward_metric_name``,
+        which is "default" by default.
+        Otherwise it raises an exception indicating multiple metrics are found.
    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
    __doc__ = _enas_note.format(

@@ -87,23 +96,26 @@ class EnasLightningModule(RandomSamplingLightningModule):
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

-    automatic_optimization = False
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

    def __init__(self,
                 inner_module: pl.LightningModule,
                 *,
-                 ctrl_kwargs: Dict[str, Any] = None,
+                 ctrl_kwargs: dict[str, Any] | None = None,
                 entropy_weight: float = 1e-4,
                 skip_weight: float = .8,
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: float = 20,
                 ctrl_grad_clip: float = 0,
-                 mutation_hooks: List[MutationHook] = None):
+                 reward_metric_name: str | None = None,
+                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)

        # convert parameter spec to legacy ReinforceField
        # this part will be refactored
-        self.nas_fields: List[ReinforceField] = []
+        self.nas_fields: list[ReinforceField] = []
        for name, param_spec in self.search_space_spec().items():
            if param_spec.chosen_size not in (1, None):
                raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')

@@ -116,6 +128,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
+        self.reward_metric_name = reward_metric_name

    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

@@ -127,34 +140,35 @@ class EnasLightningModule(RandomSamplingLightningModule):
        if source == 'train':
            # step 1: train model params
            self.resample()
-            self.call_user_optimizers('zero_grad')
+            self.call_weight_optimizers('zero_grad')
            loss_and_metrics = self.model.training_step(batch, batch_idx)
            w_step_loss = loss_and_metrics['loss'] \
                if isinstance(loss_and_metrics, dict) else loss_and_metrics
            self.manual_backward(w_step_loss)
-            self.call_user_optimizers('step')
+            self.call_weight_optimizers('step')
            return loss_and_metrics

        if source == 'val':
            # step 2: train ENAS agent
-            x, y = batch
-            arc_opt = self.architecture_optimizers
+            arc_opt = self.architecture_optimizers()
+            if not isinstance(arc_opt, optim.Optimizer):
+                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
            arc_opt.zero_grad()
            self.resample()
-            with torch.no_grad():
-                logits = self.model(x)
+            self.model.validation_step(batch, batch_idx)

            # use the default metric of self.model as reward function
-            if len(self.model.metrics) == 1:
-                _, metric = next(iter(self.model.metrics.items()))
+            if len(self.trainer.callback_metrics) == 1:
+                _, metric = next(iter(self.trainer.callback_metrics.items()))
            else:
-                if 'default' not in self.model.metrics.keys():
-                    raise KeyError('model.metrics should contain a ``default`` key when'
-                                   'there are multiple metrics')
-                metric = self.model.metrics['default']
-            reward = metric(logits, y)
+                metric_name = self.reward_metric_name or 'default'
+                if metric_name not in self.trainer.callback_metrics:
+                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
+                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                metric = self.trainer.callback_metrics[metric_name]
+            reward: float = metric.item()

            if self.entropy_weight:
-                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
+                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:

@@ -183,7 +197,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
        with torch.no_grad():
            return self._interpret_controller_sampling_result(self.controller.resample())

-    def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
+    def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
        """Convert ``{label: index}`` to ``{label: name}``"""
        space_spec = self.search_space_spec()
        for key in list(sample.keys()):
@@ -10,8 +10,10 @@ For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to b
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""

+from __future__ import annotations
+
import warnings
-from typing import Any, List, Optional, Type, Union, Tuple
+from typing import Any, Type

import torch.nn as nn
from torch.utils.data import DataLoader

@@ -33,10 +35,10 @@ class OneShotStrategy(BaseStrategy):
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

-        self.model: Optional[BaseOneShotLightningModule] = None
+        self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
-        -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
+    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
+        -> DataLoader | tuple[DataLoader, DataLoader]:
        """
        One-shot strategy typically requires a customized dataloader.

@@ -51,9 +53,9 @@ class OneShotStrategy(BaseStrategy):
        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

-        py_model: nn.Module = base_model.python_object
-        if not isinstance(py_model, nn.Module):
+        if not isinstance(base_model.python_object, nn.Module):
            raise TypeError('Model is not a nn.Module. ' + _reason)
+        py_model: nn.Module = base_model.python_object

        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + _reason)

@@ -64,8 +66,10 @@ class OneShotStrategy(BaseStrategy):
        evaluator_module: LightningModule = base_model.evaluator.module
        evaluator_module.set_model(py_model)
-        self.model: BaseOneShotLightningModule = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
+        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)

        evaluator: Lightning = base_model.evaluator
+        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
+            raise TypeError('Train or val dataloader is not set.')
        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
        if isinstance(dataloader, tuple):
            dataloader, val_loader = dataloader

@@ -73,7 +77,7 @@ class OneShotStrategy(BaseStrategy):
        else:
            evaluator.trainer.fit(self.model, dataloader)

-    def export_top_models(self, top_k: int = 1) -> List[Any]:
+    def export_top_models(self, top_k: int = 1) -> list[Any]:
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
......
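For context on the ``_get_dataloader`` hook typed above, a hypothetical subclass sketch (not from this commit): a concrete strategy either returns a ``(train, val)`` tuple so both loaders are passed to ``trainer.fit``, or builds a single combined loader such as the ``InterleavedTrainValDataLoader`` referenced in the DARTS module.

class MyOneShotStrategy(OneShotStrategy):
    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # Simplest option: hand both loaders back; run() unpacks the tuple
        # and fits with a separate validation loader.
        return train_dataloader, val_dataloaders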
@@ -26,8 +26,10 @@ The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
"""

+from __future__ import annotations
+
import operator
-from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
+from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast

import numpy as np
import torch

@@ -58,8 +60,8 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
    for i in range(len(slice_)):
        if isinstance(slice_[i], list):
            # convert list of slices to mask
-            mask = np.zeros(shape[i], dtype=np.bool)
-            for sl in slice_[i]:
+            mask = np.zeros(shape[i], dtype=np.bool)  # type: ignore
+            for sl in cast(List[slice], slice_[i]):
                mask[sl] = 1
            result.append(mask)
        else:

@@ -67,7 +69,7 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
    return tuple(result)

-def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_slice, float]]]) -> T:
+def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
    # slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
    # or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}

@@ -84,27 +86,27 @@ def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_s
            # create a mask with weight w
            with torch.no_grad():
                mask = zeros_like(weight)
-                mask[_eliminate_list_slice(weight.shape, sl)] = 1
+                mask[_eliminate_list_slice(weight.shape, sl)] = 1  # type: ignore

            # track gradients here
-            masks.append((mask * wt))
+            masks.append(mask * wt)  # type: ignore

        masks = sum(masks)

-        return masks * weight
+        return masks * weight  # type: ignore

    else:
        # for unweighted case, we slice it directly.
        def _do_slice(arr, slice_):
-            return arr[_eliminate_list_slice(arr.shape, slice_)]
+            return arr[_eliminate_list_slice(arr.shape, slice_)]  # type: ignore

        # sometimes, we don't need slice.
        # this saves an op on computational graph, which will hopefully make training faster
        # Use a dummy array to check this. Otherwise it would be too complex.
-        dummy_arr = np.zeros(weight.shape, dtype=np.bool)
-        no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape
+        dummy_arr = np.zeros(weight.shape, dtype=np.bool)  # type: ignore
+        no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape

        if no_effect:
            return weight
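A small worked example of the weighted branch above (illustrative shapes and weights, not from the commit):

# weight has 4 output channels; candidate widths 2 and 3 carry
# architecture weights 0.7 and 0.3.
weight = torch.ones(4)
slice_ = [((slice(0, 2),), 0.7), ((slice(0, 3),), 0.3)]
# per-candidate masks: [1, 1, 0, 0] * 0.7 and [1, 1, 1, 0] * 0.3
# summed mask:         [1.0, 1.0, 0.3, 0.0]
# so _slice_weight(weight, slice_) scales the channels to tensor([1.0, 1.0, 0.3, 0.0])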
@@ -128,14 +130,14 @@ class Slicable(Generic[T]):
            raise TypeError(f'Unsuppoted weight type: {type(weight)}')
        self.weight = weight

-    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
+    def __getitem__(self, index: slice_type | multidim_slice) -> T:
        if not isinstance(index, tuple):
            index = (index, )
        index = cast(multidim_slice, index)

        # Get the dict value in index's leafs
        # There can be at most one dict
-        leaf_dict: Optional[Dict[int, float]] = None
+        leaf_dict: dict[int, float] | None = None
        for maybe_weighted in _iterate_over_multidim_slice(index):
            for d in maybe_weighted.leaf_values():
                if isinstance(d, dict):

@@ -166,10 +168,10 @@ class MaybeWeighted:
    """

    def __init__(self,
-                 value: Optional[int_or_int_dict] = None, *,
-                 lhs: Optional[Union['MaybeWeighted', int]] = None,
-                 rhs: Optional[Union['MaybeWeighted', int]] = None,
-                 operation: Optional[Callable[[int, int], int]] = None):
+                 value: int_or_int_dict | None = None, *,
+                 lhs: 'MaybeWeighted' | int | None = None,
+                 rhs: 'MaybeWeighted' | int | None = None,
+                 operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
        if operation is None:
            if not isinstance(value, (int, dict)):
                raise TypeError(f'Unsupported value type: {type(value)}')

@@ -178,7 +180,7 @@ class MaybeWeighted:
        self.rhs = rhs
        self.operation = operation

-    def leaf_values(self) -> Iterator[Dict[int, float]]:
+    def leaf_values(self) -> Iterator[int_or_int_dict]:
        """Iterate over values on leaf nodes."""
        if self.value is not None:
            yield self.value

@@ -188,7 +190,7 @@ class MaybeWeighted:
        if isinstance(self.rhs, MaybeWeighted):
            yield from self.rhs.leaf_values()

-    def evaluate(self, value_fn: _value_fn_type = None) -> int:
+    def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
        """Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.
        If ``value_fn`` is none, no replacement will happen and the raw value will be used.
        """

@@ -200,11 +202,12 @@ class MaybeWeighted:
        if isinstance(self.lhs, MaybeWeighted):
            eval_lhs = self.lhs.evaluate(value_fn)
        else:
-            eval_lhs = self.lhs
+            eval_lhs = cast(int, self.lhs)
        if isinstance(self.rhs, MaybeWeighted):
            eval_rhs = self.rhs.evaluate(value_fn)
        else:
-            eval_rhs = self.rhs
+            eval_rhs = cast(int, self.rhs)
+        assert self.operation is not None
        return self.operation(eval_lhs, eval_rhs)

    def __repr__(self):
......
@@ -2,6 +2,7 @@

# Licensed under the MIT license.

# pylint: skip-file
+# type: ignore

"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. The complete support is an undergoing work item.

@@ -96,9 +97,9 @@ class DifferentiableSuperConv2d(nn.Conv2d):
        ----------
        input_weight : Tensor
            the weight to be weighted summed
-        masks : List[Tensor]
+        masks : list[Tensor]
            weight masks.
-        thresholds : List[float]
+        thresholds : list[float]
            thresholds, should have a length of ``len(masks) - 1``
        indicator : Callable[[Tensor, float], float]
            take a tensor and a threshold as input, and output the weight
......
@@ -4,19 +4,26 @@

"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""

+from __future__ import annotations
+
import itertools
-from typing import List, Any, Dict, Tuple, Optional, Union
+from typing import Any, TypeVar, List, cast
+
+import numpy as np
+import torch

from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX

Choice = Any
+T = TypeVar('T')

__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']

-def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
+def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
    """Find all leaf nodes in ``value_choices``,
    save them into in the format of ``{label: parameter_spec}``.
    """

@@ -33,7 +40,7 @@ def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, Paramete
    return result

-def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
+def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
    """To evaluate a composition of value-choice with a dict,
    with format of ``{label: chosen_value}``.
    The implementation is two-pass. We first get a list of values,

@@ -56,8 +63,10 @@ def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str
    return value_choice.evaluate(choice_inner_values)
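As an illustration of these helpers (a sketch, assuming ``ValueChoice`` from ``nni.retiarii.nn.pytorch``; the numbers are made up):

import nni.retiarii.nn.pytorch as retiarii_nn

ks = retiarii_nn.ValueChoice([3, 5, 7], label='kernel_size')
padding = (ks - 1) // 2     # arithmetic on a ValueChoice yields a composition (ChoiceOf[int])

evaluate_value_choice_with_dict(padding, {'kernel_size': 5})             # -> 2
traverse_all_options(padding)                                            # -> [1, 2, 3]
traverse_all_options(padding, weights={'kernel_size': [0.2, 0.3, 0.5]})  # -> [(1, 0.2), (2, 0.3), (3, 0.5)]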
def traverse_all_options(value_choice: ValueChoiceX, def traverse_all_options(
weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]: value_choice: ChoiceOf[T],
weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
"""Traverse all possible computation outcome of a value choice. """Traverse all possible computation outcome of a value choice.
If ``weights`` is not None, it will also compute the probability of each possible outcome. If ``weights`` is not None, it will also compute the probability of each possible outcome.
...@@ -65,33 +74,33 @@ def traverse_all_options(value_choice: ValueChoiceX, ...@@ -65,33 +74,33 @@ def traverse_all_options(value_choice: ValueChoiceX,
---------- ----------
value_choice : ValueChoiceX value_choice : ValueChoiceX
The value choice to traverse. The value choice to traverse.
weights : Optional[Dict[str, List[float]]], default = None weights : Optional[dict[str, list[float]]], default = None
If there's a prior on leaf nodes, and we intend to know the (joint) prior on results, If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
weights can be provided. The key is the label, and the value is a list of floats indicating probabilities. weights can be provided. The key is the label, and the value is a list of floats indicating probabilities.
Normally, they should sum up to 1, but we will not check them in this function. Normally, they should sum up to 1, but we will not check them in this function.
Returns Returns
------- -------
List[Union[Tuple[Any, float], Any]] list[Union[tuple[Any, float], Any]]
Results will be sorted and duplicates will be eliminated. Results will be sorted and duplicates will be eliminated.
If weights is provided, the return value will be a list of tuple, with option and its weight. If weights is provided, the return value will be a list of tuple, with option and its weight.
Otherwise, it will be a list of options. Otherwise, it will be a list of options.
""" """
# get a dict of {label: list of tuple of choice and weight} # get a dict of {label: list of tuple of choice and weight}
leafs: Dict[str, List[Tuple[Choice, float]]] = {} leafs: dict[str, list[tuple[T, float]]] = {}
for label, param_spec in dedup_inner_choices([value_choice]).items(): for label, param_spec in dedup_inner_choices([value_choice]).items():
if weights is not None: if weights is not None:
if label not in weights: if label not in weights:
raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}') raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
if len(weights[label]) != param_spec.size: if len(weights[label]) != param_spec.size:
raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found') raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
leafs[label] = list(zip(param_spec.values, weights[label])) leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
else: else:
# create a dummy weight of zero, in case weights are not provided. # create a dummy weight of zero, in case weights are not provided.
leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size))) leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))
# result is a dict from an option to its weight # result is a dict from an option to its weight
result: Dict[str, Optional[float]] = {} result: dict[T, float | None] = {}
labels, values = list(leafs.keys()), list(leafs.values()) labels, values = list(leafs.keys()), list(leafs.values())
if not labels: if not labels:
...@@ -126,6 +135,6 @@ def traverse_all_options(value_choice: ValueChoiceX, ...@@ -126,6 +135,6 @@ def traverse_all_options(value_choice: ValueChoiceX,
result[eval_res] = chosen_weight result[eval_res] = chosen_weight
if weights is None: if weights is None:
return sorted(result.keys()) return sorted(result.keys()) # type: ignore
else: else:
return sorted(result.items()) return sorted(result.items()) # type: ignore
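The traversal above appears to enumerate the cartesian product of all labelled leaves and carry the product of their weights to each outcome. A self-contained sketch of that idea follows (independent of the ValueChoice machinery; how the real helper treats duplicate outcomes may differ from the summation used here).

import itertools

def joint_outcomes(leaf_values, leaf_weights, evaluate):
    # leaf_values:  {label: [candidate, ...]}
    # leaf_weights: {label: [probability, ...]}, aligned with leaf_values
    # evaluate:     maps a full assignment {label: value} to the composed outcome
    labels = list(leaf_values)
    outcomes = {}
    for indices in itertools.product(*(range(len(leaf_values[l])) for l in labels)):
        chosen = {l: leaf_values[l][i] for l, i in zip(labels, indices)}
        prob = 1.0
        for l, i in zip(labels, indices):
            prob *= leaf_weights[l][i]
        key = evaluate(chosen)
        outcomes[key] = outcomes.get(key, 0.0) + prob   # merge duplicate outcomes
    return sorted(outcomes.items())

# Kernel sizes 3/5 with prior [0.25, 0.75]; the outcome is padding = k // 2.
assert joint_outcomes({'k': [3, 5]}, {'k': [0.25, 0.75]}, lambda c: c['k'] // 2) == [(1, 0.25), (2, 0.75)]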
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Any, Dict, Tuple, Union from __future__ import annotations
from typing import Any
import torch.nn as nn import torch.nn as nn
...@@ -24,13 +26,13 @@ class BaseSuperNetModule(nn.Module): ...@@ -24,13 +26,13 @@ class BaseSuperNetModule(nn.Module):
rather than their compositions. rather than their compositions.
""" """
def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]: def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
""" """
Resample the super-net module. Resample the super-net module.
Parameters Parameters
---------- ----------
memo : Dict[str, Any] memo : dict[str, Any]
Used to ensure the consistency of samples with the same label. Used to ensure the consistency of samples with the same label.
Returns Returns
...@@ -40,19 +42,19 @@ class BaseSuperNetModule(nn.Module): ...@@ -40,19 +42,19 @@ class BaseSuperNetModule(nn.Module):
""" """
raise NotImplementedError() raise NotImplementedError()
def export(self, memo: Dict[str, Any]) -> Dict[str, Any]: def export(self, memo: dict[str, Any]) -> dict[str, Any]:
""" """
Export the final architecture within this module. Export the final architecture within this module.
It should have the same keys as ``search_space_spec()``. It should have the same keys as ``search_space_spec()``.
Parameters Parameters
---------- ----------
memo : Dict[str, Any] memo : dict[str, Any]
Use memo to avoid the same label gets exported multiple times. Use memo to avoid the same label gets exported multiple times.
""" """
raise NotImplementedError() raise NotImplementedError()
def search_space_spec(self) -> Dict[str, ParameterSpec]: def search_space_spec(self) -> dict[str, ParameterSpec]:
""" """
Space specification (sample points). Space specification (sample points).
Mapping from spec name to ParameterSpec. The names in choices should be in the same format as export. Mapping from spec name to ParameterSpec. The names in choices should be in the same format as export.
...@@ -64,8 +66,8 @@ class BaseSuperNetModule(nn.Module): ...@@ -64,8 +66,8 @@ class BaseSuperNetModule(nn.Module):
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \ def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]: 'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
"""This is a mutation hook that creates a :class:`BaseSuperNetModule`. """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
The method should be implemented in each specific super-net module, The method should be implemented in each specific super-net module,
because they usually have specific rules about what kind of modules to operate on. because they usually have specific rules about what kind of modules to operate on.
...@@ -84,7 +86,7 @@ class BaseSuperNetModule(nn.Module): ...@@ -84,7 +86,7 @@ class BaseSuperNetModule(nn.Module):
Returns Returns
------- -------
Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]] Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks. The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details. See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details.
""" """
......
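To make the hook contract concrete, here is a minimal, hypothetical mutation hook. ``ToySuperLinear`` and its behaviour are invented for illustration; a real super-net module would subclass ``BaseSuperNetModule`` and also implement ``resample``, ``export`` and ``search_space_spec``.

import torch.nn as nn

class ToySuperLinear(nn.Linear):
    """Hypothetical super-net replacement for nn.Linear."""

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        # Only act on plain nn.Linear; returning None means the hook does not apply here.
        if type(module) is nn.Linear:
            replacement = cls(module.in_features, module.out_features,
                              bias=module.bias is not None)
            # The second element (True) suppresses any follow-up hooks for this node.
            return replacement, True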
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import functools import functools
import warnings import warnings
from typing import List, Tuple, Optional, Dict, Any, Union from typing import Any, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -21,7 +23,9 @@ from ._valuechoice_utils import traverse_all_options ...@@ -21,7 +23,9 @@ from ._valuechoice_utils import traverse_all_options
class GumbelSoftmax(nn.Softmax): class GumbelSoftmax(nn.Softmax):
"""Wrapper of ``F.gumbel_softmax``. dim = -1 by default.""" """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""
def __init__(self, dim: Optional[int] = -1) -> None: dim: int
def __init__(self, dim: int = -1) -> None:
super().__init__(dim) super().__init__(dim)
self.tau = 1 self.tau = 1
self.hard = False self.hard = False
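The forward pass of this wrapper is not shown in the hunk; presumably it delegates to ``F.gumbel_softmax`` with the stored settings, roughly as sketched below (an assumption based on the attributes above, not the diffed code).

import torch.nn.functional as F

def forward(self, inputs):
    # Hypothetical body: apply Gumbel-Softmax with the configured temperature,
    # hardness and dimension.
    return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)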
...@@ -42,7 +46,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule): ...@@ -42,7 +46,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
Parameters Parameters
---------- ----------
paths : List[Tuple[str, nn.Module]] paths : list[tuple[str, nn.Module]]
Layers to choose from. Each is a tuple of a name and its module. Layers to choose from. Each is a tuple of a name and its module.
alpha : Tensor alpha : Tensor
Tensor that stores the "learnable" weights. Tensor that stores the "learnable" weights.
...@@ -59,9 +63,9 @@ class DifferentiableMixedLayer(BaseSuperNetModule): ...@@ -59,9 +63,9 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
Name of the choice. Name of the choice.
""" """
_arch_parameter_names: List[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__() super().__init__()
self.op_names = [] self.op_names = []
if len(alpha) != len(paths): if len(alpha) != len(paths):
...@@ -82,7 +86,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule): ...@@ -82,7 +86,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
"""Choose the operator with the maximum logit.""" """Choose the operator with the maximum logit."""
if self.label in memo: if self.label in memo:
return {} # nothing new to export return {} # nothing new to export
return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]} return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def search_space_spec(self): def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ), return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
...@@ -149,9 +153,9 @@ class DifferentiableMixedInput(BaseSuperNetModule): ...@@ -149,9 +153,9 @@ class DifferentiableMixedInput(BaseSuperNetModule):
Name of the choice. Name of the choice.
""" """
_arch_parameter_names: List[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__() super().__init__()
self.n_candidates = n_candidates self.n_candidates = n_candidates
if len(alpha) != n_candidates: if len(alpha) != n_candidates:
...@@ -240,9 +244,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): ...@@ -240,9 +244,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
won't be optimized. won't be optimized.
""" """
_arch_parameter_names: List[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None: def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys as `operation.mutable_arguments` # Sampling arguments. This should have the same keys as `operation.mutable_arguments`
operation._arch_alpha = nn.ParameterDict() operation._arch_alpha = nn.ParameterDict()
for name, spec in operation.search_space_spec().items(): for name, spec in operation.search_space_spec().items():
...@@ -254,20 +258,20 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): ...@@ -254,20 +258,20 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3) alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
operation._arch_alpha[name] = alpha operation._arch_alpha[name] = alpha
operation.parameters = functools.partial(self.parameters, self=operation) # bind self operation.parameters = functools.partial(self.parameters, module=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, self=operation) operation.named_parameters = functools.partial(self.named_parameters, module=operation)
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@staticmethod @staticmethod
def parameters(self, *args, **kwargs): def parameters(module, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs): for _, p in module.named_parameters(*args, **kwargs):
yield p yield p
@staticmethod @staticmethod
def named_parameters(self, *args, **kwargs): def named_parameters(module, *args, **kwargs):
arch = kwargs.pop('arch', False) arch = kwargs.pop('arch', False)
for name, p in super(self.__class__, self).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names): if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch: if arch:
yield name, p yield name, p
...@@ -275,22 +279,24 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): ...@@ -275,22 +279,24 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
if not arch: if not arch:
yield name, p yield name, p
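With these overrides in place, a trainer can presumably separate architecture parameters from ordinary weights when building optimizers. A usage sketch follows; the ``arch=True`` keyword is taken from the code above, while the optimizer choices and hyperparameters are illustrative assumptions.

import torch.optim as optim

def make_optimizers(operation):
    # operation: a mixed operation whose parameters()/named_parameters() were
    # rebound by MixedOpDifferentiablePolicy, so `arch=True` filters the alphas.
    arch_params = [p for _, p in operation.named_parameters(arch=True)]
    weight_params = [p for _, p in operation.named_parameters()]   # arch defaults to False
    arch_optimizer = optim.Adam(arch_params, lr=3e-4, betas=(0.5, 0.999))
    weight_optimizer = optim.SGD(weight_params, lr=0.025, momentum=0.9)
    return arch_optimizer, weight_optimizer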
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]: def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Differentiable. Do nothing in resample.""" """Differentiable. Do nothing in resample."""
return {} return {}
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]: def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice.""" """Export is also random for each leaf value choice."""
result = {} result = {}
for name, spec in operation.search_space_spec().items(): for name, spec in operation.search_space_spec().items():
if name in result: if name in result:
continue continue
chosen_index = torch.argmax(operation._arch_alpha[name]).item() chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index] result[name] = spec.values[chosen_index]
return result return result
def forward_argument(self, operation: MixedOperation, name: str) -> Union[Dict[Any, float], Any]: def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
if name in operation.mutable_arguments: if name in operation.mutable_arguments:
weights = {label: operation._softmax(alpha) for label, alpha in operation._arch_alpha.items()} weights: dict[str, torch.Tensor] = {
label: cast(nn.Module, operation._softmax)(alpha) for label, alpha in cast(dict, operation._arch_alpha).items()
}
return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights)) return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
return operation.init_arguments[name] return operation.init_arguments[name]
...@@ -6,9 +6,11 @@ Operations that support weight sharing at a fine-grained level, ...@@ -6,9 +6,11 @@ Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement. which is commonly known as super-kernel (as in channel search), or weight entanglement.
""" """
from __future__ import annotations
import inspect import inspect
import itertools import itertools
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast from typing import Any, Type, TypeVar, cast, Union, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -37,7 +39,7 @@ class MixedOperationSamplingPolicy: ...@@ -37,7 +39,7 @@ class MixedOperationSamplingPolicy:
One SamplingStrategy corresponds to one mixed operation. One SamplingStrategy corresponds to one mixed operation.
""" """
def __init__(self, operation: 'MixedOperation', memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None: def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
"""At init, the sampling policy can prepare basic parameters, """At init, the sampling policy can prepare basic parameters,
and store them in operation if they need back propagation. and store them in operation if they need back propagation.
...@@ -47,11 +49,11 @@ class MixedOperationSamplingPolicy: ...@@ -47,11 +49,11 @@ class MixedOperationSamplingPolicy:
""" """
pass pass
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]: def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
"""The handler of :meth:`MixedOperation.resample`.""" """The handler of :meth:`MixedOperation.resample`."""
raise NotImplementedError() raise NotImplementedError()
def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]: def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
"""The handler of :meth:`MixedOperation.export`.""" """The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError() raise NotImplementedError()
...@@ -90,7 +92,7 @@ class MixedOperation(BaseSuperNetModule): ...@@ -90,7 +92,7 @@ class MixedOperation(BaseSuperNetModule):
""" """
bound_type: Type[nn.Module] # defined in subclass bound_type: Type[nn.Module] # defined in subclass
argument_list: List[str] # defined in subclass argument_list: list[str] # defined in subclass
sampling_policy: MixedOperationSamplingPolicy sampling_policy: MixedOperationSamplingPolicy
...@@ -114,11 +116,11 @@ class MixedOperation(BaseSuperNetModule): ...@@ -114,11 +116,11 @@ class MixedOperation(BaseSuperNetModule):
appended by forward arguments in the ``bound_type``.""" appended by forward arguments in the ``bound_type``."""
raise NotImplementedError() raise NotImplementedError()
def __init__(self, module_kwargs: Dict[str, Any]) -> None: def __init__(self, module_kwargs: dict[str, Any]) -> None:
# Concerned arguments # Concerned arguments
self.mutable_arguments: Dict[str, ValueChoiceX] = {} self.mutable_arguments: dict[str, ValueChoiceX] = {}
# Useful when retrieving arguments without ValueChoice # Useful when retrieving arguments without ValueChoice
self.init_arguments: Dict[str, Any] = {**module_kwargs} self.init_arguments: dict[str, Any] = {**module_kwargs}
self._fill_missing_init_arguments() self._fill_missing_init_arguments()
# get init default # get init default
...@@ -134,7 +136,7 @@ class MixedOperation(BaseSuperNetModule): ...@@ -134,7 +136,7 @@ class MixedOperation(BaseSuperNetModule):
super_init_kwargs[key] = value super_init_kwargs[key] = value
# get all inner leaf value choices # get all inner leaf value choices
self._space_spec: Dict[str, ParameterSpec] = dedup_inner_choices(self.mutable_arguments.values()) self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
super().__init__(**super_init_kwargs) super().__init__(**super_init_kwargs)
...@@ -156,17 +158,17 @@ class MixedOperation(BaseSuperNetModule): ...@@ -156,17 +158,17 @@ class MixedOperation(BaseSuperNetModule):
"""Find value choice in module's arguments and replace the whole module""" """Find value choice in module's arguments and replace the whole module"""
has_valuechoice = False has_valuechoice = False
if isinstance(module, cls.bound_type) and is_traceable(module): if isinstance(module, cls.bound_type) and is_traceable(module):
for arg in itertools.chain(module.trace_args, module.trace_kwargs.values()): for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX): if isinstance(arg, ValueChoiceX):
has_valuechoice = True has_valuechoice = True
if has_valuechoice: if has_valuechoice:
if module.trace_args: if module.trace_args:
raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. ' raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
'Please enable ``kw_only`` on nni.trace.') 'Please enable ``kw_only`` on nni.trace.')
# save type and kwargs # save type and kwargs
mixed_op = cls(module.trace_kwargs) mixed_op = cls(cast(dict, module.trace_kwargs))
if 'mixed_op_sampling' not in mutate_kwargs: if 'mixed_op_sampling' not in mutate_kwargs:
raise ValueError('A sampling policy for the mixed op is required, but none was found in `mutate_kwargs`.') raise ValueError('A sampling policy for the mixed op is required, but none was found in `mutate_kwargs`.')
...@@ -229,15 +231,15 @@ class MixedLinear(MixedOperation, nn.Linear): ...@@ -229,15 +231,15 @@ class MixedLinear(MixedOperation, nn.Linear):
out_features: int_or_int_dict, out_features: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor: inputs: torch.Tensor) -> torch.Tensor:
in_features = _W(in_features) in_features_ = _W(in_features)
out_features = _W(out_features) out_features_ = _W(out_features)
weight = _S(self.weight)[:out_features] weight = _S(self.weight)[:out_features_]
weight = _S(weight)[:, :in_features] weight = _S(weight)[:, :in_features_]
if self.bias is None: if self.bias is None:
bias = self.bias bias = self.bias
else: else:
bias = _S(self.bias)[:out_features] bias = _S(self.bias)[:out_features_]
return F.linear(inputs, weight, bias) return F.linear(inputs, weight, bias)
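Stripped of the ``_S``/``_W`` helpers (which also handle weighted, dict-valued dimensions), the core of the mixed linear op is corner-slicing the largest weight to the sampled sizes. A self-contained sketch of that plain-integer case, with invented names:

import torch
import torch.nn.functional as F

def sliced_linear(inputs, full_weight, full_bias, in_features, out_features):
    # Take the top-left corner of the super-net weight matrix.
    weight = full_weight[:out_features, :in_features]
    bias = full_bias[:out_features] if full_bias is not None else None
    return F.linear(inputs, weight, bias)

full_weight = torch.randn(64, 128)     # the super-net stores the largest shape
full_bias = torch.randn(64)
x = torch.randn(8, 32)
assert sliced_linear(x, full_weight, full_bias, 32, 16).shape == (8, 16)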
...@@ -278,7 +280,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -278,7 +280,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
] ]
@staticmethod @staticmethod
def _to_tuple(value: scalar_or_scalar_dict[T]) -> Tuple[T, T]: def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
if not isinstance(value, tuple): if not isinstance(value, tuple):
return (value, value) return (value, value)
return value return value
...@@ -318,33 +320,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -318,33 +320,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
if any(isinstance(arg, dict) for arg in [stride, dilation, groups]): if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
raise ValueError('stride, dilation, groups does not support weighted sampling.') raise ValueError('stride, dilation, groups does not support weighted sampling.')
in_channels = _W(in_channels) in_channels_ = _W(in_channels)
out_channels = _W(out_channels) out_channels_ = _W(out_channels)
# slice prefix # slice prefix
# For groups > 1, we use groups to slice input weights # For groups > 1, we use groups to slice input weights
weight = _S(self.weight)[:out_channels] weight = _S(self.weight)[:out_channels_]
weight = _S(weight)[:, :in_channels // groups] weight = _S(weight)[:, :in_channels_ // groups]
# slice center # slice center
if isinstance(kernel_size, dict): if isinstance(kernel_size, dict):
# If kernel size is a dict, ignore choices in padding.
if isinstance(self.padding, str):
raise ValueError(f'Using "{self.padding}" as padding is not supported.')
padding = self.padding # max padding, must be a tuple padding = self.padding # max padding, must be a tuple
kernel_a, kernel_b = self._to_tuple(kernel_size) kernel_a, kernel_b = self._to_tuple(kernel_size)
kernel_a, kernel_b = _W(kernel_a), _W(kernel_b) kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
max_kernel_a, max_kernel_b = self.kernel_size # self.kernel_size must be a tuple max_kernel_a, max_kernel_b = self.kernel_size # self.kernel_size must be a tuple
kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a) // 2, (max_kernel_b - kernel_b) // 2 kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a, kernel_b_top:kernel_b_top + kernel_b] weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
bias = _S(self.bias)[:out_channels] if self.bias is not None else None bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
# The remaining parameters only need to be converted to tuples # The remaining parameters only need to be converted to tuples
stride = self._to_tuple(stride) stride_ = self._to_tuple(stride)
dilation = self._to_tuple(dilation) dilation_ = self._to_tuple(dilation)
if self.padding_mode != 'zeros': if self.padding_mode != 'zeros':
return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode), return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, bias, stride, (0, 0), dilation, groups) weight, bias, stride_, (0, 0), dilation_, groups)
return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups) return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)
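The kernel-size handling above slices the centre of the largest kernel. For plain integers (ignoring the weighted case and the padding bookkeeping) the arithmetic reduces to roughly the following sketch; the helper name is invented.

import torch

def center_kernel_slice(full_weight, kernel_size):
    # full_weight: (out_channels, in_channels, K, K) super-kernel, kernel_size <= K.
    max_k = full_weight.shape[-1]
    start = (max_k - kernel_size) // 2
    return full_weight[:, :, start:start + kernel_size, start:start + kernel_size]

w = torch.randn(8, 4, 7, 7)
assert center_kernel_slice(w, 3).shape == (8, 4, 3, 3)
assert center_kernel_slice(w, 7).shape == w.shape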
class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d): class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
...@@ -388,13 +394,15 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d): ...@@ -388,13 +394,15 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
if num_features < self.num_features: if num_features < self.num_features:
weight = weight[:num_features] weight = weight[:num_features]
bias = bias[:num_features] bias = bias[:num_features]
running_mean = running_mean[:num_features] if running_mean is not None:
running_var = running_var[:num_features] running_mean = running_mean[:num_features]
if running_var is not None:
running_var = running_var[:num_features]
if self.training: if self.training:
bn_training = True bn_training = True
else: else:
bn_training = (self.running_mean is None) and (self.running_var is None) bn_training = (running_mean is None) and (running_var is None)
return F.batch_norm( return F.batch_norm(
inputs, inputs,
...@@ -473,7 +481,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): ...@@ -473,7 +481,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
def super_init_argument(self, name: str, value_choice: ValueChoiceX): def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice)) return max(traverse_all_options(value_choice))
def _to_proj_slice(self, embed_dim: _W) -> List[slice]: def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
# slice three parts, corresponding to q, k, v respectively # slice three parts, corresponding to q, k, v respectively
return [ return [
slice(embed_dim), slice(embed_dim),
...@@ -484,12 +492,12 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): ...@@ -484,12 +492,12 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
def forward_with_args( def forward_with_args(
self, self,
embed_dim: int_or_int_dict, num_heads: int, embed_dim: int_or_int_dict, num_heads: int,
kdim: Optional[int_or_int_dict], vdim: Optional[int_or_int_dict], kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
dropout: float, dropout: float,
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None, key_padding_mask: torch.Tensor | None = None,
need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None need_weights: bool = True, attn_mask: torch.Tensor | None = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
if any(isinstance(arg, dict) for arg in [num_heads, dropout]): if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
raise ValueError('num_heads, dropout do not support weighted sampling.') raise ValueError('num_heads, dropout do not support weighted sampling.')
...@@ -511,26 +519,26 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): ...@@ -511,26 +519,26 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
else: else:
used_embed_dim = embed_dim used_embed_dim = embed_dim
embed_dim = _W(embed_dim) embed_dim_ = _W(embed_dim)
# in-projection weights & biases have q, k, v weights concatenated together # in-projection weights & biases have q, k, v weights concatenated together
in_proj_bias: Optional[Tensor] = None in_proj_bias: Tensor | None = None
in_proj_weight: Optional[Tensor] = None in_proj_weight: Tensor | None = None
if self.in_proj_bias is not None: if self.in_proj_bias is not None:
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)] in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
if self.in_proj_weight is not None: if self.in_proj_weight is not None:
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim] in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim] out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None
if not qkv_same_embed_dim: if not qkv_same_embed_dim:
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim] q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim] k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
k_proj = _S(k_proj)[:, :_W(kdim)] k_proj = _S(k_proj)[:, :_W(kdim)]
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim] v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
v_proj = _S(v_proj)[:, :_W(vdim)] v_proj = _S(v_proj)[:, :_W(vdim)]
# The rest is basically the same as PyTorch # The rest is basically the same as PyTorch
...@@ -560,7 +568,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): ...@@ -560,7 +568,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
return attn_output, attn_output_weights return attn_output, attn_output_weights
NATIVE_MIXED_OPERATIONS: List[Type[MixedOperation]] = [ NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
MixedLinear, MixedLinear,
MixedConv2d, MixedConv2d,
MixedBatchNorm2d, MixedBatchNorm2d,
......
...@@ -9,7 +9,9 @@ The support remains limited. Known limitations include: ...@@ -9,7 +9,9 @@ The support remains limited. Known limitations include:
- The code contains duplicates. Needs refactoring. - The code contains duplicates. Needs refactoring.
""" """
from typing import List, Tuple, Optional, cast from __future__ import annotations
from typing import cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -48,13 +50,13 @@ class ProxylessMixedLayer(DifferentiableMixedLayer): ...@@ -48,13 +50,13 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
_arch_parameter_names = ['_arch_alpha', '_binary_gates'] _arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(paths, alpha, softmax, label) super().__init__(paths, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3) self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
# like sampling-based methods, it has a ``_sampled``. # like sampling-based methods, it has a ``_sampled``.
self._sampled: Optional[str] = None self._sampled: str | None = None
self._sample_idx: Optional[int] = None self._sample_idx: int | None = None
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
def run_function(ops, active_id, **kwargs): def run_function(ops, active_id, **kwargs):
...@@ -130,10 +132,10 @@ class ProxylessMixedInput(DifferentiableMixedInput): ...@@ -130,10 +132,10 @@ class ProxylessMixedInput(DifferentiableMixedInput):
_arch_parameter_names = ['_arch_alpha', '_binary_gates'] _arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(n_candidates, n_chosen, alpha, softmax, label) super().__init__(n_candidates, n_chosen, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3) self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
self._sampled: Optional[int] = None self._sampled: int | None = None
def forward(self, inputs): def forward(self, inputs):
def run_function(active_sample): def run_function(active_sample):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import random import random
from typing import Optional, List, Tuple, Union, Dict, Any from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -28,14 +30,14 @@ class PathSamplingLayer(BaseSuperNetModule): ...@@ -28,14 +30,14 @@ class PathSamplingLayer(BaseSuperNetModule):
Name of the choice. Name of the choice.
""" """
def __init__(self, paths: List[Tuple[str, nn.Module]], label: str): def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
super().__init__() super().__init__()
self.op_names = [] self.op_names = []
for name, module in paths: for name, module in paths:
self.add_module(name, module) self.add_module(name, module)
self.op_names.append(name) self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.' assert self.op_names, 'There has to be at least one op to choose from.'
self._sampled: Optional[Union[List[str], str]] = None # sampled can be either a list of op names or a single op name self._sampled: list[str] | str | None = None # sampled can be either a list of op names or a single op name
self.label = label self.label = label
def resample(self, memo): def resample(self, memo):
...@@ -89,7 +91,7 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -89,7 +91,7 @@ class PathSamplingInput(BaseSuperNetModule):
self.n_candidates = n_candidates self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
self._sampled: Optional[Union[List[int], int]] = None self._sampled: list[int] | int | None = None
self.label = label self.label = label
def _random_choose_n(self): def _random_choose_n(self):
...@@ -159,11 +161,11 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): ...@@ -159,11 +161,11 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
We sample the leaf nodes and composite them into the values of the arguments. We sample the leaf nodes and composite them into the values of the arguments.
""" """
def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None: def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys as `operation.mutable_arguments` # Sampling arguments. This should have the same keys as `operation.mutable_arguments`
self._sampled: Optional[Dict[str, Any]] = None self._sampled: dict[str, Any] | None = None
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]: def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Random sample for each leaf value choice.""" """Random sample for each leaf value choice."""
result = {} result = {}
space_spec = operation.search_space_spec() space_spec = operation.search_space_spec()
...@@ -181,7 +183,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): ...@@ -181,7 +183,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result return result
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]: def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice.""" """Export is also random for each leaf value choice."""
result = {} result = {}
space_spec = operation.search_space_spec() space_spec = operation.search_space_spec()
......
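The sampling-based classes in this file share one idea: each forward pass runs exactly one randomly chosen candidate. A self-contained toy version of that idea follows; the class and method names are invented and this is not the nni implementation.

import random
import torch
import torch.nn as nn

class ToyPathSamplingLayer(nn.Module):
    """Toy illustration: resample picks one candidate, forward runs only that one."""

    def __init__(self, paths):
        super().__init__()
        self.op_names = [name for name, _ in paths]
        for name, module in paths:
            self.add_module(name, module)
        self._sampled = None

    def resample(self):
        self._sampled = random.choice(self.op_names)
        return self._sampled

    def forward(self, x):
        assert self._sampled is not None, 'call resample() before forward()'
        return getattr(self, self._sampled)(x)

layer = ToyPathSamplingLayer([('conv3x3', nn.Conv2d(3, 8, 3, padding=1)),
                              ('conv5x5', nn.Conv2d(3, 8, 5, padding=2))])
layer.resample()
out = layer(torch.randn(1, 3, 32, 32))     # shape (1, 8, 32, 32) for either path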
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import cast
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader, Dataset
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice from nni.nas.pytorch.mutables import InputChoice, LayerChoice
...@@ -155,7 +159,7 @@ def replace_layer_choice(root_module, init_fn, modules=None): ...@@ -155,7 +159,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
Returns Returns
------- -------
List[Tuple[str, nn.Module]] list[tuple[str, nn.Module]]
A list of tuples of layer choice keys (names) and replaced modules. A list of tuples of layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules) return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
...@@ -176,7 +180,7 @@ def replace_input_choice(root_module, init_fn, modules=None): ...@@ -176,7 +180,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
Returns Returns
------- -------
List[Tuple[str, nn.Module]] list[tuple[str, nn.Module]]
A list of tuples of input choice keys (names) and replaced modules. A list of tuples of input choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules) return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
...@@ -200,15 +204,19 @@ class InterleavedTrainValDataLoader(DataLoader): ...@@ -200,15 +204,19 @@ class InterleavedTrainValDataLoader(DataLoader):
Example Example
-------- --------
Fit your dataloaders into a parallel one. Fit your dataloaders into a parallel one.
>>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader) >>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)
Then you can use the ``para_loader`` as a normal training loader. Then you can use the ``para_loader`` as a normal training loader.
""" """
def __init__(self, train_dataloader, val_dataloader): def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
self.train_loader = train_dataloader if isinstance(val_dataloader, list):
self.val_loader = val_dataloader raise TypeError('Validation dataloader of type list is not supported.')
self.train_loader: DataLoader = train_dataloader
self.val_loader: DataLoader = val_dataloader
self.equal_len = len(train_dataloader) == len(val_dataloader) self.equal_len = len(train_dataloader) == len(val_dataloader)
self.train_longer = len(train_dataloader) > len(val_dataloader) self.train_longer = len(train_dataloader) > len(val_dataloader)
super().__init__(None) super().__init__(cast(Dataset, None))
def __iter__(self): def __iter__(self):
self.train_iter = iter(self.train_loader) self.train_iter = iter(self.train_loader)
...@@ -268,13 +276,17 @@ class ConcatenateTrainValDataLoader(DataLoader): ...@@ -268,13 +276,17 @@ class ConcatenateTrainValDataLoader(DataLoader):
Example Example
-------- --------
Fit your dataloaders into a concatenated one. Fit your dataloaders into a concatenated one.
>>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloader) >>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloader)
Then you can use the ``concat_loader`` as a normal training loader. Then you can use the ``concat_loader`` as a normal training loader.
""" """
def __init__(self, train_dataloader, val_dataloader): def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
self.train_loader = train_dataloader if isinstance(val_dataloader, list):
self.val_loader = val_dataloader raise TypeError('Validation dataloader of type list is not supported.')
super().__init__(None) self.train_loader: DataLoader = train_dataloader
self.val_loader: DataLoader = val_dataloader
super().__init__(cast(Dataset, None))
def __iter__(self): def __iter__(self):
self.cur_iter = iter(self.train_loader) self.cur_iter = iter(self.train_loader)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"nni/retiarii/execution/cgo_engine.py", "nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer", "nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo", "nni/retiarii/evaluator/pytorch/cgo",
"nni/retiarii/oneshot",
"nni/smartparam.py", "nni/smartparam.py",
"nni/tools/annotation", "nni/tools/annotation",
"nni/tools/gpu_tool", "nni/tools/gpu_tool",
......
...@@ -130,13 +130,16 @@ def test_fit_api(): ...@@ -130,13 +130,16 @@ def test_fit_api():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform) train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform) test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), def lightning(): return pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.1, # for faster training val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
progress_bar_refresh_rate=progress_bar_refresh_rate) max_epochs=1, limit_train_batches=0.1, # for faster training
lightning.fit(lambda: MNISTModel()) progress_bar_refresh_rate=progress_bar_refresh_rate)
lightning.fit(MNISTModel) # Lightning will have some cache in models / trainers,
lightning.fit(MNISTModel()) # which is problematic if we call fit multiple times.
lightning().fit(lambda: MNISTModel())
lightning().fit(MNISTModel)
lightning().fit(MNISTModel())
_reset() _reset()
......
...@@ -12,6 +12,7 @@ from nni.retiarii import strategy, model_wrapper, basic_unit ...@@ -12,6 +12,7 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.strategy import BaseStrategy
class DepthwiseSeparableConv(nn.Module): class DepthwiseSeparableConv(nn.Module):
...@@ -237,8 +238,12 @@ def _test_strategy(strategy_, support_value_choice=True): ...@@ -237,8 +238,12 @@ def _test_strategy(strategy_, support_value_choice=True):
] ]
for (base_model, evaluator), support_or_not in to_test: for (base_model, evaluator), support_or_not in to_test:
print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not) if isinstance(strategy_, BaseStrategy):
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_) strategy = strategy_
else:
strategy = strategy_(base_model, evaluator)
print('Testing:', type(strategy).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy)
config = RetiariiExeConfig() config = RetiariiExeConfig()
config.execution_engine = 'oneshot' config.execution_engine = 'oneshot'
...@@ -263,7 +268,12 @@ def test_proxyless(): ...@@ -263,7 +268,12 @@ def test_proxyless():
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas(): def test_enas():
_test_strategy(strategy.ENAS()) def strategy_fn(base_model, evaluator):
if isinstance(base_model, MultiHeadAttentionNet):
return strategy.ENAS(reward_metric_name='val_mse')
return strategy.ENAS(reward_metric_name='val_acc')
_test_strategy(strategy_fn)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
......