"...composable_kernel_rocm.git" did not exist on "ccaea50e46f1294163a302270c6b28333503156a"
Unverified Commit 8d5f643c authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Hyper-parameter Choice in Retiarii (#4609)

parent ba771871
......@@ -5,18 +5,19 @@ import math
import itertools
import operator
import warnings
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar, Sequence
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder', 'ChosenInputs']
class LayerChoice(Mutable):
......@@ -870,7 +871,6 @@ class ValueChoice(ValueChoiceX, Mutable):
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
self._accessor = []
@property
def label(self):
......@@ -906,6 +906,149 @@ class ValueChoice(ValueChoiceX, Mutable):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
ValueType = TypeVar('ValueType')


class ModelParameterChoice:
    """ModelParameterChoice chooses one hyper-parameter from ``candidates``.

    .. attention::

        This API is internal, and does not guarantee forward-compatibility.

    It's quite similar to :class:`ValueChoice`, but unlike :class:`ValueChoice`,
    it always returns a fixed value, even at the construction of base model.
    This makes it highly flexible (e.g., can be used in for-loop, if-condition, as argument of any function). For example: ::

        self.has_auxiliary_head = ModelParameterChoice([False, True])
        # this will raise error if you use `ValueChoice`
        if self.has_auxiliary_head is True:  # or self.has_auxiliary_head
            self.auxiliary_head = Head()
        else:
            self.auxiliary_head = None
        print(type(self.has_auxiliary_head))  # <class 'bool'>

    The working mechanism of :class:`ModelParameterChoice` is that, it registers itself
    in the ``model_wrapper``, as a hyper-parameter of the model, and then returns the value specified with ``default``.
    At base model construction, the default value will be used (as a mocked hyper-parameter).
    In trial, the hyper-parameter selected by strategy will be used.

    Although flexible, we still recommend using :class:`ValueChoice` in favor of :class:`ModelParameterChoice`,
    because information is lost when using :class:`ModelParameterChoice` in exchange for its flexibility,
    making it incompatible with one-shot strategies and non-python execution engines.

    .. warning::

        :class:`ModelParameterChoice` can NOT be nested.

    .. tip::

        Although called :class:`ModelParameterChoice`, it's meant to tune hyper-parameter of architecture.
        It's NOT used to tune model-training hyper-parameters like ``learning_rate``.
        If you need to tune ``learning_rate``, please use :class:`ValueChoice` on arguments of :class:`nni.retiarii.Evaluator`.

    Parameters
    ----------
    candidates : list of any
        List of values to choose from.
    prior : list of float
        Prior distribution to sample from. Currently has no effect.
    default : Callable[[List[Any]], Any] or Any
        Function that selects one from ``candidates``, or a candidate.
        Use :meth:`ModelParameterChoice.FIRST` or :meth:`ModelParameterChoice.LAST` to take the first or last item.
        Default: :meth:`ModelParameterChoice.FIRST`
    label : str
        Identifier of the value choice.

    Warnings
    --------
    :class:`ModelParameterChoice` is incompatible with one-shot strategies and non-python execution engines.
    Sometimes, the same search space implemented **without** :class:`ModelParameterChoice` can be simpler, and explored
    with more types of search strategies. For example, the following usages are equivalent: ::

        # with ModelParameterChoice
        depth = nn.ModelParameterChoice(list(range(3, 10)))
        blocks = []
        for i in range(depth):
            blocks.append(Block())

        # w/o ModelParameterChoice
        blocks = Repeat(Block(), (3, 9))

    Examples
    --------
    Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.

    >>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
    >>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
    """

    # FIXME: fix signature in docs
    # FIXME: prior is designed but not supported yet

    def __new__(cls, candidates: List[ValueType], *,
                prior: Optional[List[float]] = None,
                default: Union[Callable[[List[ValueType]], ValueType], ValueType] = None,
                label: Optional[str] = None) -> ValueType:
        # Actually, creating a `ModelParameterChoice` never creates one.
        # It either returns the value fixed by an applied architecture,
        # or registers a ParameterSpec and returns the default value.
        if default is None:
            default = cls.FIRST
        try:
            return cls.create_fixed_module(candidates, label=label)
        except NoContextError:
            return cls.create_default(candidates, default, label)

    @staticmethod
    def create_default(candidates: List[ValueType],
                       default: Union[Callable[[List[ValueType]], ValueType], ValueType],
                       label: Optional[str]) -> ValueType:
        """Register a ParameterSpec for ``candidates`` in the current ModelNamespace
        and return the (possibly computed) default value.

        Raises
        ------
        TypeError
            If ``default`` is neither a member of ``candidates`` nor a callable
            that selects one.
        """
        if default not in candidates:
            # `default` may be a selector function (e.g. ModelParameterChoice.FIRST).
            # Check callability explicitly rather than matching "not callable" in a
            # TypeError message, which could misfire if a callable `default` itself
            # raised a TypeError containing that phrase.
            if callable(default):
                default = default(candidates)
            else:
                raise TypeError("`default` is not in `candidates`, and it's also not callable.")
        label = generate_new_label(label)
        parameter_spec = ParameterSpec(
            label,       # name
            'choice',    # TODO: support more types
            candidates,  # value
            (label,),    # we don't have nested now
            True,        # yes, categorical
        )
        # there could be duplicates. Dedup is done in mutator
        ModelNamespace.current_context().parameter_specs.append(parameter_spec)
        return default

    @classmethod
    def create_fixed_module(cls, candidates: List[ValueType], *, label: Optional[str] = None, **kwargs) -> ValueType:
        # same as ValueChoice: look up the value chosen for `label` in the fixed-arch context
        value = get_fixed_value(label)
        if value not in candidates:
            raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
        return value

    @staticmethod
    def FIRST(sequence: Sequence[ValueType]) -> ValueType:
        """Get the first item of sequence. Useful in ``default`` argument."""
        return sequence[0]

    @staticmethod
    def LAST(sequence: Sequence[ValueType]) -> ValueType:
        """Get the last item of sequence. Useful in ``default`` argument."""
        return sequence[-1]
@basic_unit
class Placeholder(nn.Module):
"""
......
......@@ -11,7 +11,7 @@ from nni.common.serializer import is_traceable
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid
from nni.retiarii.utils import ModelNamespace, uid
from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
from .component import NasBench101Cell, NasBench101Mutator
......@@ -285,6 +285,13 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
else:
model.python_init_params = {}
# hyper-parameter choice
namespace: ModelNamespace = pytorch_model._model_namespace
for param_spec in namespace.parameter_specs:
assert param_spec.categorical and param_spec.type == 'choice'
node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
node.label = param_spec.name
for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in traced arguments
if is_basic_unit(module):
......
......@@ -120,7 +120,8 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
with ModelNamespace():
self._model_namespace = ModelNamespace()
with self._model_namespace:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
......
......@@ -9,6 +9,8 @@ from contextlib import contextmanager
from typing import Any, List, Dict
from pathlib import Path
from nni.common.hpo_utils import ParameterSpec
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace']
......@@ -111,43 +113,78 @@ class ContextStack:
class ModelNamespace:
"""
To create an individual namespace for models to enable automatic numbering.
To create an individual namespace for models:
1. to enable automatic numbering;
2. to trace general information (like creation of hyper-parameters) of model.
A namespace is bounded to a key. Namespace bounded to different keys are completed isolated.
Namespace can have sub-namespaces (with the same key). The numbering will be chained (e.g., ``model_1_4_2``).
"""
def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
    """Create a namespace bound to ``key``; it becomes active on ``__enter__``."""
    # for example, key: "model_wrapper"
    self.key = key
    # the "path" of current name
    # By default, it's ``[]``
    # If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
    # See ``__enter__``.
    self.name_path: List[int] = []
    # parameter specs.
    # Currently only used to trace calls of ModelParameterChoice.
    self.parameter_specs: List[ParameterSpec] = []
def __enter__(self):
# For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
# the next thing up is [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try:
current_context = ContextStack.top(self.key)
next_uid = uid(self._simple_name(self.key, current_context))
ContextStack.push(self.key, current_context + [next_uid])
reset_uid(self._simple_name(self.key, current_context + [next_uid]))
parent_context: 'ModelNamespace' = ModelNamespace.current_context(self.key)
next_uid = uid(parent_context._simple_name())
self.name_path = parent_context.name_path + [next_uid]
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
except NoContextError:
ContextStack.push(self.key, [])
reset_uid(self._simple_name(self.key, []))
# not found, no existing namespace
self.name_path = []
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
def __exit__(self, *args, **kwargs):
    # Deactivate this namespace: pop it off the context stack for `self.key`.
    ContextStack.pop(self.key)
def _simple_name(self) -> str:
    """Return the flattened name, e.g. ``model_wrapper_1_2`` for path ``[1, 2]``."""
    parts = [self.key] + [str(step) for step in self.name_path]
    return '_'.join(parts)
def __repr__(self):
    # Debug-friendly summary: flattened name plus number of recorded parameter specs.
    name = self._simple_name()
    spec_count = len(self.parameter_specs)
    return f'ModelNamespace(name={name}, num_specs={spec_count})'
# Access the current context in the model #
@staticmethod
def current_context(key: str = _DEFAULT_MODEL_NAMESPACE) -> 'ModelNamespace':
    """Get the current (innermost) namespace registered under ``key``.

    Raises
    ------
    NoContextError
        If no namespace is active, i.e., model creation is not wrapped in ``@model_wrapper``.
    """
    try:
        return ContextStack.top(key)
    except NoContextError:
        # `from None` suppresses the implicit exception chaining, so users see
        # only the actionable message instead of a noisy "during handling of
        # the above exception" traceback.
        raise NoContextError('ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.') from None
@staticmethod
def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
"""Get the next label for API calls, with automatic numbering."""
try:
current_context = ContextStack.top(key)
except NoContextError:
# fallback to use "default" namespace
return ModelNamespace._simple_name('default', [uid()])
next_uid = uid(ModelNamespace._simple_name(key, current_context))
return ModelNamespace._simple_name(key, current_context + [next_uid])
# it won't be registered
warnings.warn('ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
'Some features might not work. This will be an error in future releases.', RuntimeWarning)
current_context = ModelNamespace('default')
@staticmethod
def _simple_name(key: str, lst: List[Any]) -> str:
return key + ''.join(['_' + str(k) for k in lst])
next_uid = uid(current_context._simple_name())
return current_context._simple_name() + '_' + str(next_uid)
def get_current_context(key: str) -> Any:
......
......@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice
from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack, original_state_dict_hooks
from nni.retiarii.utils import ContextStack, NoContextError, original_state_dict_hooks
class EnumerateSampler(Sampler):
......@@ -849,6 +849,65 @@ class Python(GraphIR):
@unittest.skip
def test_valuechoice_getitem_functional_expression(self): ...
def test_hyperparameter_choice(self):
    """A single ModelParameterChoice yields one mutator, enumerated in candidate order."""
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.aux = nn.ModelParameterChoice([False, True])

        def forward(self, x):
            return x

    model, mutators = self._get_model_with_mutators(Net())
    self.assertEqual(len(mutators), 1)

    # EnumerateSampler walks candidates in order: False first, then True.
    sampler = EnumerateSampler()
    for expected in (False, True):
        mutated = _apply_all_mutators(model, mutators, sampler)
        self.assertEqual(self._get_converted_pytorch_model(mutated).aux, expected)
def test_hyperparameter_choice_parameter(self):
    """Choices sharing label 'a' collapse to one mutator and stay consistent across modules."""
    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.aux = torch.nn.Parameter(
                torch.zeros(1, nn.ModelParameterChoice([64, 128, 256], label='a'), 3, 3)
            )

        def forward(self):
            return self.aux

    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.choice = nn.ModelParameterChoice([64, 128, 256], label='a')
            self.inner = Inner()

        def forward(self):
            param = self.inner()
            assert param.size(1) == self.choice
            return param

    model, mutators = self._get_model_with_mutators(Net())
    self.assertEqual(len(mutators), 1)

    # Sample repeatedly; all three candidate widths should eventually show up.
    sampler = RandomSampler()
    observed_sizes = set()
    for _ in range(20):
        model = _apply_all_mutators(model, mutators, sampler)
        output = self._get_converted_pytorch_model(model)()
        observed_sizes.add(output.size(1))
    self.assertSetEqual(observed_sizes, {64, 128, 256})
def test_hyperparameter_choice_no_model_wrapper(self):
    """Without @model_wrapper there is no ModelNamespace, so construction must fail."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.choice = nn.ModelParameterChoice([64, 128, 256], label='a')

    with self.assertRaises(NoContextError):
        Net()
def test_cell_loose_end(self):
@model_wrapper
class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment