Unverified commit 18962129 authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
@@ -4,9 +4,9 @@
 import inspect
 import os
 import warnings
-from typing import Any, TypeVar, Union
+from typing import Any, TypeVar, Type
-from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
+from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
 from .utils import ModelNamespace
 __all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
@@ -48,7 +48,7 @@ def serialize_cls(cls):
     return trace(cls)
-def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
+def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
     """
     To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
@@ -75,17 +75,17 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
         return cls
     import torch.nn as nn
-    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
+    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'  # type: ignore
     cls = trace(cls)
-    cls._nni_basic_unit = basic_unit_tag
+    cls._nni_basic_unit = basic_unit_tag  # type: ignore
     _torchscript_patch(cls)
     return cls
-def model_wrapper(cls: T) -> Union[T, Traceable]:
+def model_wrapper(cls: T) -> T:
     """
     Wrap the base model (search space). For example,
@@ -113,7 +113,7 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
         return cls
     import torch.nn as nn
-    assert issubclass(cls, nn.Module)
+    assert issubclass(cls, nn.Module)  # type: ignore
     # subclass can still use trace info
     wrapper = trace(cls, inheritable=True)
@@ -146,7 +146,7 @@ def is_model_wrapped(cls_or_instance) -> bool:
     return getattr(cls_or_instance, '_nni_model_wrapper', False)
-def _check_wrapped(cls: T, rewrap: str) -> bool:
+def _check_wrapped(cls: Type, rewrap: str) -> bool:
     wrapped = None
     if is_model_wrapped(cls):
         wrapped = 'model_wrapper'
...
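Note on the change above: returning plain `T` instead of `Union[T, Traceable]` lets a static checker keep the decorated class's own type, while the runtime-only attributes are hidden behind `# type: ignore`. A minimal sketch of the pattern, with hypothetical names (`tag_unit` and `_my_tag` are invented for illustration, not NNI APIs):

```python
# Minimal sketch (hypothetical names, not NNI's actual implementation).
# A class decorator annotated as (cls: T) -> T preserves the precise type
# of the wrapped class; the runtime-only tag is invisible to the checker.
from typing import TypeVar

T = TypeVar('T')

def tag_unit(cls: T) -> T:
    cls._my_tag = True  # type: ignore  # runtime-only attribute
    return cls

@tag_unit
class Conv:
    def forward(self, x: int) -> int:
        return x

c = Conv()           # a checker still sees Conv here, so...
print(c.forward(1))  # ...attribute access type-checks as usual
```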
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 # This file might cause import error for those who didn't install RL-related dependencies
 import logging
 import threading
 from multiprocessing.pool import ThreadPool
+from typing import Tuple
 import gym
 import numpy as np
 import tianshou
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from gym import spaces
 from tianshou.data import to_torch
 from tianshou.env.worker import EnvWorker
+from nni.typehint import TypedDict
 from .utils import get_targeted_model
 from ..graph import ModelStatus
 from ..execution import submit_models, wait_models
@@ -76,8 +83,13 @@ class MultiThreadEnvWorker(EnvWorker):
         self.pool.terminate()
         return self.env.close()
+class ObservationType(TypedDict):
+    action_history: np.ndarray
+    cur_step: int
+    action_dim: int
-class ModelEvaluationEnv(gym.Env):
+class ModelEvaluationEnv(gym.Env[ObservationType, int]):
     def __init__(self, base_model, mutators, search_space):
         self.base_model = base_model
         self.mutators = mutators
@@ -98,7 +110,7 @@ class ModelEvaluationEnv(gym.Env):
     def action_space(self):
         return spaces.Discrete(self.action_dim)
-    def reset(self):
+    def reset(self) -> ObservationType:
         self.action_history = np.zeros(self.num_steps, dtype=np.int32)
         self.cur_step = 0
         self.sample = {}
@@ -108,14 +120,14 @@ class ModelEvaluationEnv(gym.Env):
             'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
         }
-    def step(self, action):
+    def step(self, action: int) -> Tuple[ObservationType, float, bool, dict]:
         cur_key = self.ss_keys[self.cur_step]
         assert action < len(self.search_space[cur_key]), \
             f'Current action {action} out of range {self.search_space[cur_key]}.'
         self.action_history[self.cur_step] = action
         self.sample[cur_key] = self.search_space[cur_key][action]
         self.cur_step += 1
-        obs = {
+        obs: ObservationType = {
             'action_history': self.action_history,
             'cur_step': self.cur_step,
             'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
@@ -129,7 +141,7 @@ class ModelEvaluationEnv(gym.Env):
             wait_models(model)
             if model.status == ModelStatus.Failed:
                 return self.reset(), 0., False, {}
-            rew = model.metric
+            rew = float(model.metric)
             _logger.info(f'Model metric received as reward: {rew}')
             return obs, rew, True, {}
         else:
@@ -147,7 +159,7 @@ class Preprocessor(nn.Module):
         self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
     def forward(self, obs):
-        seq = nn.functional.pad(obs['action_history'] + 1, (1, 1))  # pad the start token and end token
+        seq = F.pad(obs['action_history'] + 1, (1, 1))  # pad the start token and end token
         # end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
         seq = self.embedding(seq.long())
         feature, _ = self.rnn(seq)
@@ -167,7 +179,7 @@ class Actor(nn.Module):
         # to take care of choices with different number of options
         mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
         out[mask.to(out.device)] = float('-inf')
-        return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
+        return F.softmax(out, dim=-1), kwargs.get('state', None)
 class Critic(nn.Module):
...
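For context, the `ObservationType` TypedDict introduced above gives the observation dicts returned by `reset()` and `step()` a checkable shape. A standalone sketch (importing `TypedDict` from `typing_extensions` rather than `nni.typehint`, which re-exports it, so the snippet runs on its own):

```python
# Sketch of how a TypedDict makes gym-style observation dicts checkable.
import numpy as np
from typing_extensions import TypedDict

class ObservationType(TypedDict):
    action_history: np.ndarray
    cur_step: int
    action_dim: int

def make_obs(history: np.ndarray, step: int, dim: int) -> ObservationType:
    # a type checker verifies the keys and value types of this literal
    return {'action_history': history, 'cur_step': step, 'action_dim': dim}

obs = make_obs(np.zeros(4, dtype=np.int32), 0, 4)
print(obs['cur_step'])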
@@ -14,5 +14,5 @@ class BaseStrategy(abc.ABC):
     def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
         pass
-    def export_top_models(self) -> List[Any]:
+    def export_top_models(self, top_k: int) -> List[Any]:
         raise NotImplementedError('"export_top_models" is not implemented.')
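The `export_top_models` signature now takes an explicit `top_k`. A hypothetical subclass illustrating the contract (`MyStrategy` and its `history` list are invented for the example; the base class is a simplified stand-in so the sketch runs alone):

```python
# Hypothetical strategy subclass showing the new export_top_models contract.
import abc
from typing import Any, List

class BaseStrategy(abc.ABC):
    @abc.abstractmethod
    def run(self, base_model: Any, applied_mutators: List[Any]) -> None:
        ...

    def export_top_models(self, top_k: int) -> List[Any]:
        raise NotImplementedError('"export_top_models" is not implemented.')

class MyStrategy(BaseStrategy):
    def __init__(self) -> None:
        self.history: List[Any] = []  # models kept best-first by the strategy

    def run(self, base_model: Any, applied_mutators: List[Any]) -> None:
        self.history.append(base_model)

    def export_top_models(self, top_k: int) -> List[Any]:
        return self.history[:top_k]
```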
@@ -6,7 +6,7 @@ import itertools
 import logging
 import random
 import time
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Sequence, Optional
 from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
 from .base import BaseStrategy
@@ -30,6 +30,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
     history = set()
     search_space_values = copy.deepcopy(list(search_space.values()))
     while True:
+        selected: Optional[Sequence[int]] = None
         for retry_count in range(retries):
             selected = [random.choice(v) for v in search_space_values]
             if not dedup:
@@ -41,6 +42,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
             if retry_count + 1 == retries:
                 _logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
                 return
+        assert selected is not None, 'Retry attempts exhausted.'
         yield {key: value for key, value in zip(keys, selected)}
...
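The two added lines above follow a common pattern for satisfying pyright's possibly-unbound analysis: a variable first assigned inside a `for` loop is pre-declared as `Optional` and narrowed with an `assert` before use after the loop. A self-contained sketch (the function and its acceptance test are invented for illustration):

```python
# Sketch of the Optional-plus-assert narrowing pattern used in the diff above.
import random
from typing import Optional, Sequence

def pick_even(values: Sequence[int], retries: int = 100) -> int:
    selected: Optional[int] = None  # pre-declare so the checker sees a binding
    for _ in range(retries):
        selected = random.choice(values)
        if selected % 2 == 0:  # arbitrary acceptance test for the sketch
            break
    assert selected is not None, 'Retry attempts exhausted.'  # Optional[int] -> int
    return selected

print(pick_even([1, 2, 3, 4]))
```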
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import logging
 from typing import Optional, Callable
...
@@ -3,6 +3,7 @@
 import logging
 import time
+from typing import Optional
 from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
@@ -15,8 +16,8 @@ _logger = logging.getLogger(__name__)
 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
-        self.cur_sample = None
+        self.cur_sample: Optional[dict] = None
-        self.index = None
+        self.index: Optional[int] = None
         self.total_parameters = {}
     def update_sample_space(self, sample_space):
@@ -34,6 +35,7 @@ class TPESampler(Sampler):
         self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
     def choice(self, candidates, mutator, model, index):
+        assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
         chosen = self.cur_sample[str(self.index)]
         self.index += 1
         return chosen
...
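The assert added in `choice()` narrows the `Optional` attributes declared in `__init__`; pyright treats `isinstance` checks as type guards. A reduced sketch of the same shape (`TPELikeSampler` is a stand-in, not the real class):

```python
# Sketch of isinstance-based narrowing for Optional instance attributes.
from typing import Optional

class TPELikeSampler:
    def __init__(self) -> None:
        self.cur_sample: Optional[dict] = None
        self.index: Optional[int] = None

    def load(self, sample: dict) -> None:
        self.cur_sample = sample
        self.index = 0

    def choice(self) -> object:
        # below this assert, the checker treats index as int and
        # cur_sample as dict, so indexing and += type-check cleanly
        assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
        chosen = self.cur_sample[str(self.index)]
        self.index += 1
        return chosen

s = TPELikeSampler()
s.load({'0': 'conv3x3'})
print(s.choice())  # conv3x3
```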
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 import collections
 import logging
 from typing import Dict, Any, List
...
@@ -25,4 +25,6 @@ if __name__ == '__main__':
     elif args.exec == 'benchmark':
         from .execution.benchmark import BenchmarkExecutionEngine
         engine = BenchmarkExecutionEngine
+    else:
+        raise ValueError(f'Unrecognized benchmark name: {args.exec}')
     engine.trial_execute_graph()
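The added `else` branch closes a gap where an unrecognized `--exec` value would leave `engine` unbound and fail later with a `NameError` (pyright flags exactly that). A reduced sketch of the dispatch shape, with stand-in names:

```python
# Exhaustive if/elif/else so `engine` is always bound before use.
# The engine names are stand-ins, not the real execution engine classes.
def select_engine(exec_name: str) -> str:
    if exec_name == 'base':
        engine = 'BaseExecutionEngine'
    elif exec_name == 'benchmark':
        engine = 'BenchmarkExecutionEngine'
    else:
        raise ValueError(f'Unrecognized benchmark name: {exec_name}')
    return engine

print(select_engine('benchmark'))
```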
@@ -6,7 +6,7 @@ import itertools
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Any, List, Dict
+from typing import Any, List, Dict, cast
 from pathlib import Path
 from nni.common.hpo_utils import ParameterSpec
@@ -41,9 +41,10 @@ def get_module_name(cls_or_func):
     if module_name == '__main__':
         # infer the module name with inspect
         for frm in inspect.stack():
-            if inspect.getmodule(frm[0]).__name__ == '__main__':
+            module = inspect.getmodule(frm[0])
+            if module is not None and module.__name__ == '__main__':
                 # main module found
-                main_file_path = Path(inspect.getsourcefile(frm[0]))
+                main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
                 if not Path().samefile(main_file_path.parent):
                     raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                        f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
@@ -227,6 +228,7 @@ def original_state_dict_hooks(model: Any):
         supernet_style_state_dict = model.state_dict()
     """
+    import torch.utils.hooks
     import torch.nn as nn
     assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
@@ -297,8 +299,8 @@ def original_state_dict_hooks(model: Any):
             raise KeyError(f'"{src}" not in state dict, but found in mapping.')
         destination.update(result)
+    hooks: List[torch.utils.hooks.RemovableHandle] = []
     try:
-        hooks = []
         hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
         hooks.append(model._register_state_dict_hook(state_dict_hook))
         yield
...
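Two of the fixes above are standard pyright appeasements: guarding `inspect.getmodule`'s `Optional` return with an explicit `None` check, and `cast`-ing `inspect.getsourcefile`'s `Optional[str]`. A standalone sketch of the second (the helper name is invented):

```python
# Sketch: inspect.getsourcefile is typed Optional[str]; cast tells the
# checker that None has been ruled out by the caller. cast performs no
# runtime check, so it is only safe where None is genuinely impossible.
import inspect
from pathlib import Path
from typing import cast

def source_path(obj: object) -> Path:
    src = inspect.getsourcefile(obj)  # Optional[str] as far as pyright knows
    return Path(cast(str, src))

print(source_path(inspect))  # the path of inspect.py itself
```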
@@ -6,7 +6,7 @@ Types for static checking.
 """
 __all__ = [
-    'Literal',
+    'Literal', 'TypedDict',
     'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
 ]
...
@@ -64,6 +64,9 @@ stages:
       python -m pip install "typing-extensions>=3.10"
     displayName: Resolve dependency version
+  - script: python test/vso_tools/trigger_import.py
+    displayName: Trigger import
   - script: |
       python -m pylint --rcfile pylintrc nni
     displayName: pylint
...
@@ -3,10 +3,13 @@
     "nni/algorithms",
     "nni/common/device.py",
     "nni/common/graph_utils.py",
+    "nni/common/serializer.py",
     "nni/compression",
-    "nni/nas",
-    "nni/retiarii",
+    "nni/nas/tensorflow",
+    "nni/nas/pytorch",
+    "nni/retiarii/execution/cgo_engine.py",
+    "nni/retiarii/execution/logical_optimizer",
+    "nni/retiarii/evaluator/pytorch/cgo",
+    "nni/retiarii/oneshot",
     "nni/smartparam.py",
     "nni/tools/annotation",
     "nni/tools/gpu_tool",
@@ -14,5 +17,6 @@
     "nni/tools/nnictl",
     "nni/tools/trial_tool"
   ],
-  "reportMissingImports": false
+  "reportMissingImports": false,
+  "reportPrivateImportUsage": false
 }
@@ -4,4 +4,5 @@ filterwarnings =
     ignore:Using key to access the identifier of:DeprecationWarning
     ignore:layer_choice.choices is deprecated.:DeprecationWarning
     ignore:The truth value of an empty array is ambiguous.:DeprecationWarning
+    ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
     ignore:nni.retiarii.serialize is deprecated and will be removed in future release.:DeprecationWarning
@@ -36,5 +36,9 @@
     {"head": ["conv2", null], "tail": ["pool2", null]},
     {"head": ["pool2", null], "tail": ["_outputs", 0]}
   ]
+  },
+  "_evaluator": {
+    "type": "DebugEvaluator"
   }
 }
@@ -9,6 +9,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
 from nni.retiarii.execution import set_execution_engine
 from nni.retiarii.execution.base import BaseExecutionEngine
 from nni.retiarii.execution.python import PurePythonExecutionEngine
+from nni.retiarii.graph import DebugEvaluator
 from nni.retiarii.integration import RetiariiAdvisor
@@ -51,6 +52,7 @@ class EngineTest(unittest.TestCase):
                 'edges': []
             }
         })
+        model.evaluator = DebugEvaluator()
         model.python_class = object
         submit_models(model, model)
...
"""Trigger import of some modules to write some caches,
so that static analysis (e.g., pyright) can know the type."""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
import nni
import nni.retiarii.nn.pytorch