Unverified Commit 5136a86d authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Typehint and copyright header (#4669)

parent 68347c5e
...@@ -5,6 +5,7 @@ ipython ...@@ -5,6 +5,7 @@ ipython
jupyterlab jupyterlab
nbsphinx nbsphinx
pylint pylint
pyright
pytest pytest
pytest-azurepipelines pytest-azurepipelines
pytest-cov pytest-cov
......
...@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1 ...@@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8" scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8" scipy ; python_version >= "3.8"
typeguard typeguard
typing_extensions ; python_version < "3.8" typing_extensions >= 4.0.0 ; python_version < "3.8"
websockets >= 10.1 websockets >= 10.1
Uncategorized Modules
=====================
nni.typehint
------------
.. automodule:: nni.typehint
:members:
Others
======
nni
---
nni.common
----------
nni.utils
---------
...@@ -9,4 +9,4 @@ API Reference ...@@ -9,4 +9,4 @@ API Reference
Model Compression <compression> Model Compression <compression>
Feature Engineering <./python_api/feature_engineering> Feature Engineering <./python_api/feature_engineering>
Experiment <experiment> Experiment <experiment>
Others <./python_api/others> Others <others>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .bohb_advisor import BOHB, BOHBClassArgsValidator from .bohb_advisor import BOHB, BOHBClassArgsValidator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
import warnings import warnings
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gp_tuner import GPTuner, GPClassArgsValidator from .gp_tuner import GPTuner, GPClassArgsValidator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .metis_tuner import MetisTuner, MetisClassArgsValidator from .metis_tuner import MetisTuner, MetisClassArgsValidator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .networkmorphism_tuner import NetworkMorphismTuner, NetworkMorphismClassArgsValidator from .networkmorphism_tuner import NetworkMorphismTuner, NetworkMorphismClassArgsValidator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .ppo_tuner import PPOTuner, PPOClassArgsValidator from .ppo_tuner import PPOTuner, PPOClassArgsValidator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy import copy
import logging import logging
import random import random
......
...@@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not. ...@@ -8,10 +8,13 @@ to tell whether this trial can be early stopped or not.
See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details. See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details.
""" """
from __future__ import annotations
from enum import Enum from enum import Enum
import logging import logging
from .recoverable import Recoverable from .recoverable import Recoverable
from .typehint import TrialMetric
__all__ = ['AssessResult', 'Assessor'] __all__ = ['AssessResult', 'Assessor']
...@@ -54,7 +57,7 @@ class Assessor(Recoverable): ...@@ -54,7 +57,7 @@ class Assessor(Recoverable):
:class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor` :class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor`
""" """
def assess_trial(self, trial_job_id, trial_history): def assess_trial(self, trial_job_id: str, trial_history: list[TrialMetric]) -> AssessResult:
""" """
Abstract method for determining whether a trial should be killed. Must override. Abstract method for determining whether a trial should be killed. Must override.
...@@ -91,7 +94,7 @@ class Assessor(Recoverable): ...@@ -91,7 +94,7 @@ class Assessor(Recoverable):
""" """
raise NotImplementedError('Assessor: assess_trial not implemented') raise NotImplementedError('Assessor: assess_trial not implemented')
def trial_end(self, trial_job_id, success): def trial_end(self, trial_job_id: str, success: bool) -> None:
""" """
Abstract method invoked when a trial is completed or terminated. Do nothing by default. Abstract method invoked when a trial is completed or terminated. Do nothing by default.
...@@ -103,22 +106,22 @@ class Assessor(Recoverable): ...@@ -103,22 +106,22 @@ class Assessor(Recoverable):
True if the trial successfully completed; False if failed or terminated. True if the trial successfully completed; False if failed or terminated.
""" """
def load_checkpoint(self): def load_checkpoint(self) -> None:
""" """
Internal API under revising, not recommended for end users. Internal API under revising, not recommended for end users.
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) _logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self): def save_checkpoint(self) -> None:
""" """
Internal API under revising, not recommended for end users. Internal API under revising, not recommended for end users.
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) _logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)
def _on_exit(self): def _on_exit(self) -> None:
pass pass
def _on_error(self): def _on_error(self) -> None:
pass pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .serializer import trace, dump, load, is_traceable from .serializer import trace, dump, load, is_traceable
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# Licensed under the MIT license. # Licensed under the MIT license.
""" """
Helper class and functions for tuners to deal with search space.
This script provides a more program-friendly representation of HPO search space. This script provides a more program-friendly representation of HPO search space.
The format is considered internal helper and is not visible to end users. The format is considered internal helper and is not visible to end users.
...@@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space. ...@@ -9,8 +11,16 @@ You will find this useful when you want to support nested search space.
The random tuner is an intuitive example for this utility. The random tuner is an intuitive example for this utility.
You should check its code before reading docstrings in this file. You should check its code before reading docstrings in this file.
.. attention::
This module does not guarantee forward-compatibility.
If you want to use it outside official NNI repo, it is recommended to copy the script.
""" """
from __future__ import annotations
__all__ = [ __all__ = [
'ParameterSpec', 'ParameterSpec',
'deformat_parameters', 'deformat_parameters',
...@@ -20,10 +30,16 @@ __all__ = [ ...@@ -20,10 +30,16 @@ __all__ = [
import math import math
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, Dict, NamedTuple, Tuple, cast
import numpy as np import numpy as np
from nni.typehint import Parameters, SearchSpace
ParameterKey = Tuple['str | int', ...]
FormattedParameters = Dict[ParameterKey, 'float | int']
FormattedSearchSpace = Dict[ParameterKey, 'ParameterSpec']
class ParameterSpec(NamedTuple): class ParameterSpec(NamedTuple):
""" """
Specification (aka space / range / domain) of one single parameter. Specification (aka space / range / domain) of one single parameter.
...@@ -33,29 +49,31 @@ class ParameterSpec(NamedTuple): ...@@ -33,29 +49,31 @@ class ParameterSpec(NamedTuple):
name: str # The object key in JSON name: str # The object key in JSON
type: str # "_type" in JSON type: str # "_type" in JSON
values: List[Any] # "_value" in JSON values: list[Any] # "_value" in JSON
key: Tuple[str] # The "path" of this parameter key: ParameterKey # The "path" of this parameter
categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered) categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered)
size: int = None # If it's categorical, how many candidates it has size: int = cast(int, None) # If it's categorical, how many candidates it has
# uniform distributed # uniform distributed
low: float = None # Lower bound of uniform parameter low: float = cast(float, None) # Lower bound of uniform parameter
high: float = None # Upper bound of uniform parameter high: float = cast(float, None) # Upper bound of uniform parameter
normal_distributed: bool = None # Whether this parameter is uniform or normal distrubuted normal_distributed: bool = cast(bool, None)
mu: float = None # µ of normal parameter # Whether this parameter is uniform or normal distrubuted
sigma: float = None # σ of normal parameter mu: float = cast(float, None) # µ of normal parameter
sigma: float = cast(float, None)# σ of normal parameter
q: Optional[float] = None # If not `None`, the parameter value should be an integer multiple of this q: float | None = None # If not `None`, the parameter value should be an integer multiple of this
clip: Optional[Tuple[float, float]] = None clip: tuple[float, float] | None = None
# For q(log)uniform, this equals to "values[:2]"; for others this is None # For q(log)uniform, this equals to "values[:2]"; for others this is None
log_distributed: bool = None # Whether this parameter is log distributed log_distributed: bool = cast(bool, None)
# Whether this parameter is log distributed
# When true, low/high/mu/sigma describes log of parameter value (like np.lognormal) # When true, low/high/mu/sigma describes log of parameter value (like np.lognormal)
def is_activated_in(self, partial_parameters): def is_activated_in(self, partial_parameters: FormattedParameters) -> bool:
""" """
For nested search space, check whether this parameter should be skipped for current set of paremters. For nested search space, check whether this parameter should be skipped for current set of paremters.
This function must be used in a pattern similar to random tuner. Otherwise it will misbehave. This function must be used in a pattern similar to random tuner. Otherwise it will misbehave.
...@@ -64,7 +82,7 @@ class ParameterSpec(NamedTuple): ...@@ -64,7 +82,7 @@ class ParameterSpec(NamedTuple):
return True return True
return partial_parameters[self.key[:-2]] == self.key[-2] return partial_parameters[self.key[:-2]] == self.key[-2]
def format_search_space(search_space): def format_search_space(search_space: SearchSpace) -> FormattedSearchSpace:
""" """
Convert user provided search space into a dict of ParameterSpec. Convert user provided search space into a dict of ParameterSpec.
The dict key is dict value's `ParameterSpec.key`. The dict key is dict value's `ParameterSpec.key`.
...@@ -76,7 +94,9 @@ def format_search_space(search_space): ...@@ -76,7 +94,9 @@ def format_search_space(search_space):
# Remove these comments when we drop 3.6 support. # Remove these comments when we drop 3.6 support.
return {spec.key: spec for spec in formatted} return {spec.key: spec for spec in formatted}
def deformat_parameters(formatted_parameters, formatted_search_space): def deformat_parameters(
formatted_parameters: FormattedParameters,
formatted_search_space: FormattedSearchSpace) -> Parameters:
""" """
Convert internal format parameters to users' expected format. Convert internal format parameters to users' expected format.
...@@ -88,10 +108,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space): ...@@ -88,10 +108,11 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
3. For "q*", convert x to `round(x / q) * q`, then clip into range. 3. For "q*", convert x to `round(x / q) * q`, then clip into range.
4. For nested choices, convert flatten key-value pairs into nested structure. 4. For nested choices, convert flatten key-value pairs into nested structure.
""" """
ret = {} ret: Parameters = {}
for key, x in formatted_parameters.items(): for key, x in formatted_parameters.items():
spec = formatted_search_space[key] spec = formatted_search_space[key]
if spec.categorical: if spec.categorical:
x = cast(int, x)
if spec.type == 'randint': if spec.type == 'randint':
lower = min(math.ceil(float(x)) for x in spec.values) lower = min(math.ceil(float(x)) for x in spec.values)
_assign(ret, key, int(lower + x)) _assign(ret, key, int(lower + x))
...@@ -112,7 +133,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space): ...@@ -112,7 +133,7 @@ def deformat_parameters(formatted_parameters, formatted_search_space):
_assign(ret, key, x) _assign(ret, key, x)
return ret return ret
def format_parameters(parameters, formatted_search_space): def format_parameters(parameters: Parameters, formatted_search_space: FormattedSearchSpace) -> FormattedParameters:
""" """
Convert end users' parameter format back to internal format, mainly for resuming experiments. Convert end users' parameter format back to internal format, mainly for resuming experiments.
...@@ -123,7 +144,7 @@ def format_parameters(parameters, formatted_search_space): ...@@ -123,7 +144,7 @@ def format_parameters(parameters, formatted_search_space):
for key, spec in formatted_search_space.items(): for key, spec in formatted_search_space.items():
if not spec.is_activated_in(ret): if not spec.is_activated_in(ret):
continue continue
value = parameters value: Any = parameters
for name in key: for name in key:
if isinstance(name, str): if isinstance(name, str):
value = value[name] value = value[name]
...@@ -142,8 +163,8 @@ def format_parameters(parameters, formatted_search_space): ...@@ -142,8 +163,8 @@ def format_parameters(parameters, formatted_search_space):
ret[key] = value ret[key] = value
return ret return ret
def _format_search_space(parent_key, space): def _format_search_space(parent_key: ParameterKey, space: SearchSpace) -> list[ParameterSpec]:
formatted = [] formatted: list[ParameterSpec] = []
for name, spec in space.items(): for name, spec in space.items():
if name == '_name': if name == '_name':
continue continue
...@@ -155,7 +176,7 @@ def _format_search_space(parent_key, space): ...@@ -155,7 +176,7 @@ def _format_search_space(parent_key, space):
formatted += _format_search_space(key, sub_space) formatted += _format_search_space(key, sub_space)
return formatted return formatted
def _format_parameter(key, type_, values): def _format_parameter(key: ParameterKey, type_: str, values: list[Any]):
spec = SimpleNamespace( spec = SimpleNamespace(
name = key[-1], name = key[-1],
type = type_, type = type_,
...@@ -197,7 +218,7 @@ def _format_parameter(key, type_, values): ...@@ -197,7 +218,7 @@ def _format_parameter(key, type_, values):
return ParameterSpec(**spec.__dict__) return ParameterSpec(**spec.__dict__)
def _is_nested_choices(values): def _is_nested_choices(values: list[Any]) -> bool:
assert values # choices should not be empty assert values # choices should not be empty
for value in values: for value in values:
if not isinstance(value, dict): if not isinstance(value, dict):
...@@ -206,9 +227,9 @@ def _is_nested_choices(values): ...@@ -206,9 +227,9 @@ def _is_nested_choices(values):
return False return False
return True return True
def _assign(params, key, x): def _assign(params: Parameters, key: ParameterKey, x: Any) -> None:
if len(key) == 1: if len(key) == 1:
params[key[0]] = x params[cast(str, key[0])] = x
elif isinstance(key[0], int): elif isinstance(key[0], int):
_assign(params, key[1:], x) _assign(params, key[1:], x)
else: else:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum from enum import Enum
class OptimizeMode(Enum): class OptimizeMode(Enum):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import logging import logging
from typing import Any, List, Optional from typing import Any
common_search_space_types = [ common_search_space_types = [
'choice', 'choice',
...@@ -19,7 +21,7 @@ common_search_space_types = [ ...@@ -19,7 +21,7 @@ common_search_space_types = [
def validate_search_space( def validate_search_space(
search_space: Any, search_space: Any,
support_types: Optional[List[str]] = None, support_types: list[str] | None = None,
raise_exception: bool = False # for now, in case false positive raise_exception: bool = False # for now, in case false positive
) -> bool: ) -> bool:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc import abc
import base64 import base64
import collections.abc import collections.abc
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
try: try:
import torch import torch
......
...@@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase): ...@@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase):
else: # custom algorithm else: # custom algorithm
assert self.name is None assert self.name is None
assert self.class_name assert self.class_name
assert self.code_directory is not None
if not Path(self.code_directory).is_dir(): if not Path(self.code_directory).is_dir():
raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory') raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment