Unverified Commit 18962129 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add license header and typehints for NAS (#4774)

parent 8c2f717d
......@@ -4,9 +4,9 @@
import inspect
import os
import warnings
from typing import Any, TypeVar, Union
from typing import Any, TypeVar, Type
from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
......@@ -48,7 +48,7 @@ def serialize_cls(cls):
return trace(cls)
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
......@@ -75,17 +75,17 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
cls._nni_basic_unit = basic_unit_tag # type: ignore
_torchscript_patch(cls)
return cls
def model_wrapper(cls: T) -> Union[T, Traceable]:
def model_wrapper(cls: T) -> T:
"""
Wrap the base model (search space). For example,
......@@ -113,7 +113,7 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module)
assert issubclass(cls, nn.Module) # type: ignore
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
......@@ -146,7 +146,7 @@ def is_model_wrapped(cls_or_instance) -> bool:
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: T, rewrap: str) -> bool:
def _check_wrapped(cls: Type, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# This file might cause import error for those who didn't install RL-related dependencies
import logging
import threading
from multiprocessing.pool import ThreadPool
from typing import Tuple
import gym
import numpy as np
import tianshou
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym import spaces
from tianshou.data import to_torch
from tianshou.env.worker import EnvWorker
from nni.typehint import TypedDict
from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models
......@@ -76,8 +83,13 @@ class MultiThreadEnvWorker(EnvWorker):
self.pool.terminate()
return self.env.close()
class ObservationType(TypedDict):
action_history: np.ndarray
cur_step: int
action_dim: int
class ModelEvaluationEnv(gym.Env):
class ModelEvaluationEnv(gym.Env[ObservationType, int]):
def __init__(self, base_model, mutators, search_space):
self.base_model = base_model
self.mutators = mutators
......@@ -98,7 +110,7 @@ class ModelEvaluationEnv(gym.Env):
def action_space(self):
return spaces.Discrete(self.action_dim)
def reset(self):
def reset(self) -> ObservationType:
self.action_history = np.zeros(self.num_steps, dtype=np.int32)
self.cur_step = 0
self.sample = {}
......@@ -108,14 +120,14 @@ class ModelEvaluationEnv(gym.Env):
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
}
def step(self, action):
def step(self, action: int) -> Tuple[ObservationType, float, bool, dict]:
cur_key = self.ss_keys[self.cur_step]
assert action < len(self.search_space[cur_key]), \
f'Current action {action} out of range {self.search_space[cur_key]}.'
self.action_history[self.cur_step] = action
self.sample[cur_key] = self.search_space[cur_key][action]
self.cur_step += 1
obs = {
obs: ObservationType = {
'action_history': self.action_history,
'cur_step': self.cur_step,
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
......@@ -129,7 +141,7 @@ class ModelEvaluationEnv(gym.Env):
wait_models(model)
if model.status == ModelStatus.Failed:
return self.reset(), 0., False, {}
rew = model.metric
rew = float(model.metric)
_logger.info(f'Model metric received as reward: {rew}')
return obs, rew, True, {}
else:
......@@ -147,7 +159,7 @@ class Preprocessor(nn.Module):
self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
def forward(self, obs):
seq = nn.functional.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
seq = F.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
# end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
seq = self.embedding(seq.long())
feature, _ = self.rnn(seq)
......@@ -167,7 +179,7 @@ class Actor(nn.Module):
# to take care of choices with different number of options
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
out[mask.to(out.device)] = float('-inf')
return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
return F.softmax(out, dim=-1), kwargs.get('state', None)
class Critic(nn.Module):
......
......@@ -14,5 +14,5 @@ class BaseStrategy(abc.ABC):
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
def export_top_models(self) -> List[Any]:
def export_top_models(self, top_k: int) -> List[Any]:
raise NotImplementedError('"export_top_models" is not implemented.')
......@@ -6,7 +6,7 @@ import itertools
import logging
import random
import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence, Optional
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
......@@ -30,6 +30,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history = set()
search_space_values = copy.deepcopy(list(search_space.values()))
while True:
selected: Optional[Sequence[int]] = None
for retry_count in range(retries):
selected = [random.choice(v) for v in search_space_values]
if not dedup:
......@@ -41,6 +42,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
if retry_count + 1 == retries:
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return
assert selected is not None, 'Retry attempts exhausted.'
yield {key: value for key, value in zip(keys, selected)}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional, Callable
......
......@@ -3,6 +3,7 @@
import logging
import time
from typing import Optional
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
......@@ -15,8 +16,8 @@ _logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample = None
self.index = None
self.cur_sample: Optional[dict] = None
self.index: Optional[int] = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
......@@ -34,6 +35,7 @@ class TPESampler(Sampler):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import Dict, Any, List
......
......@@ -25,4 +25,6 @@ if __name__ == '__main__':
elif args.exec == 'benchmark':
from .execution.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
raise ValueError(f'Unrecognized benchmark name: {args.exec}')
engine.trial_execute_graph()
......@@ -6,7 +6,7 @@ import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict
from typing import Any, List, Dict, cast
from pathlib import Path
from nni.common.hpo_utils import ParameterSpec
......@@ -41,9 +41,10 @@ def get_module_name(cls_or_func):
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
if inspect.getmodule(frm[0]).__name__ == '__main__':
module = inspect.getmodule(frm[0])
if module is not None and module.__name__ == '__main__':
# main module found
main_file_path = Path(inspect.getsourcefile(frm[0]))
main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
if not Path().samefile(main_file_path.parent):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
......@@ -227,6 +228,7 @@ def original_state_dict_hooks(model: Any):
supernet_style_state_dict = model.state_dict()
"""
import torch.utils.hooks
import torch.nn as nn
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
......@@ -297,8 +299,8 @@ def original_state_dict_hooks(model: Any):
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
destination.update(result)
hooks: List[torch.utils.hooks.RemovableHandle] = []
try:
hooks = []
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
hooks.append(model._register_state_dict_hook(state_dict_hook))
yield
......
......@@ -6,7 +6,7 @@ Types for static checking.
"""
__all__ = [
'Literal',
'Literal', 'TypedDict',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
]
......
......@@ -64,6 +64,9 @@ stages:
python -m pip install "typing-extensions>=3.10"
displayName: Resolve dependency version
- script: python test/vso_tools/trigger_import.py
displayName: Trigger import
- script: |
python -m pylint --rcfile pylintrc nni
displayName: pylint
......
......@@ -3,10 +3,13 @@
"nni/algorithms",
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/common/serializer.py",
"nni/compression",
"nni/nas",
"nni/retiarii",
"nni/nas/tensorflow",
"nni/nas/pytorch",
"nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo",
"nni/retiarii/oneshot",
"nni/smartparam.py",
"nni/tools/annotation",
"nni/tools/gpu_tool",
......@@ -14,5 +17,6 @@
"nni/tools/nnictl",
"nni/tools/trial_tool"
],
"reportMissingImports": false
"reportMissingImports": false,
"reportPrivateImportUsage": false
}
......@@ -4,4 +4,5 @@ filterwarnings =
ignore:Using key to access the identifier of:DeprecationWarning
ignore:layer_choice.choices is deprecated.:DeprecationWarning
ignore:The truth value of an empty array is ambiguous.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
ignore:nni.retiarii.serialize is deprecated and will be removed in future release.:DeprecationWarning
......@@ -36,5 +36,9 @@
{"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]}
]
},
"_evaluator": {
"type": "DebugEvaluator"
}
}
......@@ -9,6 +9,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.graph import DebugEvaluator
from nni.retiarii.integration import RetiariiAdvisor
......@@ -51,6 +52,7 @@ class EngineTest(unittest.TestCase):
'edges': []
}
})
model.evaluator = DebugEvaluator()
model.python_class = object
submit_models(model, model)
......
"""Trigger import of some modules to write some caches,
so that static analysis (e.g., pyright) can know the type."""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
import nni
import nni.retiarii.nn.pytorch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment