"src/vscode:/vscode.git/clone" did not exist on "6545b0849b5400433529e164d8dd83756a019358"
Unverified Commit 5136a86d authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Typehint and copyright header (#4669)

parent 68347c5e
...@@ -37,6 +37,8 @@ def to_v2(v1): ...@@ -37,6 +37,8 @@ def to_v2(v1):
_move_field(v1_trial, v2, 'command', 'trialCommand') _move_field(v1_trial, v2, 'command', 'trialCommand')
_move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory') _move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory')
_move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber') _move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber')
else:
v1_trial = {}
for algo_type in ['tuner', 'assessor', 'advisor']: for algo_type in ['tuner', 'assessor', 'advisor']:
v1_algo = v1.pop(algo_type, None) v1_algo = v1.pop(algo_type, None)
......
...@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase): ...@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
if self.password is not None: if self.password is not None:
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.') warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
elif not Path(self.ssh_key_file).is_file(): elif not Path(self.ssh_key_file).is_file(): # type: ignore
raise ValueError( raise ValueError(
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"' f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
) )
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass from dataclasses import dataclass
import json import json
from typing import List from typing import List
......
...@@ -10,7 +10,7 @@ from pathlib import Path ...@@ -10,7 +10,7 @@ from pathlib import Path
import socket import socket
from subprocess import Popen from subprocess import Popen
import time import time
from typing import Optional, Any from typing import Any
import colorama import colorama
import psutil import psutil
...@@ -34,8 +34,7 @@ class RunMode(Enum): ...@@ -34,8 +34,7 @@ class RunMode(Enum):
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout. - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits. - Detach: do not stop NNI manager when Python script exits.
NOTE: NOTE: This API is non-stable and is likely to get refactored in upcoming release.
This API is non-stable and is likely to get refactored in next release.
""" """
# TODO: # TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose. # NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
...@@ -72,15 +71,15 @@ class Experiment: ...@@ -72,15 +71,15 @@ class Experiment:
Web portal port. Or ``None`` if the experiment is not running. Web portal port. Or ``None`` if the experiment is not running.
""" """
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None: def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None):
nni.runtime.log.init_logger_for_command_line() nni.runtime.log.init_logger_for_command_line()
self.config: Optional[ExperimentConfig] = None self.config: ExperimentConfig | None = None
self.id: str = management.generate_experiment_id() self.id: str = management.generate_experiment_id()
self.port: Optional[int] = None self.port: int | None = None
self._proc: Optional[Popen] = None self._proc: Popen | psutil.Process | None = None
self.action = 'create' self._action = 'create'
self.url_prefix: Optional[str] = None self.url_prefix: str | None = None
if isinstance(config_or_platform, (str, list)): if isinstance(config_or_platform, (str, list)):
self.config = ExperimentConfig(config_or_platform) self.config = ExperimentConfig(config_or_platform)
...@@ -101,6 +100,7 @@ class Experiment: ...@@ -101,6 +100,7 @@ class Experiment:
debug debug
Whether to start in debug mode. Whether to start in debug mode.
""" """
assert self.config is not None
if run_mode is not RunMode.Detach: if run_mode is not RunMode.Detach:
atexit.register(self.stop) atexit.register(self.stop)
...@@ -114,7 +114,7 @@ class Experiment: ...@@ -114,7 +114,7 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log' log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug) nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc = launcher.start_experiment(self.action, self.id, config, port, debug, run_mode, self.url_prefix) self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
assert self._proc is not None assert self._proc is not None
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
...@@ -144,16 +144,16 @@ class Experiment: ...@@ -144,16 +144,16 @@ class Experiment:
_logger.warning('Cannot gracefully stop experiment, killing NNI process...') _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid) kill_command(self._proc.pid)
self.id = None self.id = None # type: ignore
self.port = None self.port = None
self._proc = None self._proc = None
_logger.info('Experiment stopped') _logger.info('Experiment stopped')
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool: def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
""" """
Run the experiment. Run the experiment.
If ``wait_completion`` is True, this function will block until experiment finish or error. If ``wait_completion`` is ``True``, this function will block until experiment finish or error.
Return ``True`` when experiment done; or return ``False`` when experiment failed. Return ``True`` when experiment done; or return ``False`` when experiment failed.
...@@ -247,7 +247,7 @@ class Experiment: ...@@ -247,7 +247,7 @@ class Experiment:
def _resume(exp_id, exp_dir=None): def _resume(exp_id, exp_dir=None):
exp = Experiment(None) exp = Experiment(None)
exp.id = exp_id exp.id = exp_id
exp.action = 'resume' exp._action = 'resume'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir) exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp return exp
...@@ -255,7 +255,7 @@ class Experiment: ...@@ -255,7 +255,7 @@ class Experiment:
def _view(exp_id, exp_dir=None): def _view(exp_id, exp_dir=None):
exp = Experiment(None) exp = Experiment(None)
exp.id = exp_id exp.id = exp_id
exp.action = 'view' exp._action = 'view'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir) exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp return exp
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import contextlib import contextlib
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from datetime import datetime from datetime import datetime
...@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return proc return proc
def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]: def _start_rest_server(nni_manager_args, run_mode) -> Popen:
import nni_node import nni_node
node_dir = Path(nni_node.__path__[0]) node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js') main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js] cmd = [node, '--max-old-space-size=4096', main_js]
...@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]: ...@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
from subprocess import CREATE_NEW_PROCESS_GROUP from subprocess import CREATE_NEW_PROCESS_GROUP
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP) return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else: else:
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore
def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen: def start_experiment_retiarii(exp_id, config, port, debug):
pipe = None pipe = None
proc = None proc = None
...@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool ...@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
args['dispatcher_pipe'] = pipe_path args['dispatcher_pipe'] = pipe_path
import nni_node import nni_node
node_dir = Path(nni_node.__path__[0]) node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js') main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js] cmd = [node, '--max-old-space-size=4096', main_js]
...@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int, ...@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def get_stopped_experiment_config(exp_id, exp_dir=None): def get_stopped_experiment_config(exp_id, exp_dir=None):
config_json = get_stopped_experiment_config_json(exp_id, exp_dir) config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
config = ExperimentConfig(**config_json) config = ExperimentConfig(**config_json) # type: ignore
if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory): if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)' msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger.warning(msg, exp_dir, config.experiment_working_directory) _logger.warning(msg, exp_dir, config.experiment_working_directory)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path from pathlib import Path
import random import random
import string import string
......
from io import BufferedIOBase # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
import os import os
import sys import sys
...@@ -25,7 +27,7 @@ if sys.platform == 'win32': ...@@ -25,7 +27,7 @@ if sys.platform == 'win32':
_winapi.NULL _winapi.NULL
) )
def connect(self) -> BufferedIOBase: def connect(self):
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL) _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0) fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b') self.file = os.fdopen(fd, 'w+b')
...@@ -55,7 +57,7 @@ else: ...@@ -55,7 +57,7 @@ else:
self._socket.bind(self.path) self._socket.bind(self.path)
self._socket.listen(1) # only accepts one connection self._socket.listen(1) # only accepts one connection
def connect(self) -> BufferedIOBase: def connect(self):
conn, _ = self._socket.accept() conn, _ = self._socket.accept()
self.file = conn.makefile('rwb') self.file = conn.makefile('rwb')
return self.file return self.file
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
from typing import Any, Optional from typing import Any, Optional
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import os import os
class Recoverable: class Recoverable:
def load_checkpoint(self): def load_checkpoint(self) -> None:
pass pass
def save_checkpoint(self): def save_checkpoint(self) -> None:
pass pass
def get_checkpoint_path(self): def get_checkpoint_path(self) -> str | None:
ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY') ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
if ckp_path is not None and os.path.isdir(ckp_path): if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path return ckp_path
......
...@@ -14,7 +14,7 @@ def get_config_directory() -> Path: ...@@ -14,7 +14,7 @@ def get_config_directory() -> Path:
Create it if not exist. Create it if not exist.
""" """
if os.getenv('NNI_CONFIG_DIR') is not None: if os.getenv('NNI_CONFIG_DIR') is not None:
config_dir = Path(os.getenv('NNI_CONFIG_DIR')) config_dir = Path(os.getenv('NNI_CONFIG_DIR')) # type: ignore
elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir(): elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir():
config_dir = Path(sys.prefix, 'nni') config_dir = Path(sys.prefix, 'nni')
elif sys.platform == 'win32': elif sys.platform == 'win32':
...@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path: ...@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
""" """
Get a readonly builtin config file. Get a readonly builtin config file.
""" """
return Path(nni.__path__[0], 'runtime/default_config', name) return Path(nni.__path__[0], 'runtime/default_config', name) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging import logging
import sys import sys
from datetime import datetime from datetime import datetime
...@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None: ...@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
_register_handler(StreamHandler(sys.stdout), logging.INFO) _register_handler(StreamHandler(sys.stdout), logging.INFO)
def _prepare_log_dir(path: Optional[str]) -> Path: def _prepare_log_dir(path: Path | str) -> Path:
if path is None: if path is None:
return Path() return Path()
ret = Path(path) ret = Path(path)
...@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase): ...@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
def __init__(self, log_file: TextIOBase): def __init__(self, log_file: TextIOBase):
self.file: TextIOBase = log_file self.file: TextIOBase = log_file
self.line_buffer: Optional[str] = None self.line_buffer: Optional[str] = None
self.line_start_time: Optional[datetime] = None self.line_start_time: datetime = datetime.fromtimestamp(0)
def write(self, s: str) -> int: def write(self, s: str) -> int:
cur_time = datetime.now() cur_time = datetime.now()
......
...@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
except Exception as e: except Exception as e:
_logger.error('Assessor error') _logger.error('Assessor error')
_logger.exception(e) _logger.exception(e)
raise
if isinstance(result, bool): if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad result = AssessResult.Good if result else AssessResult.Bad
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from . import proxy from . import proxy
load_jupyter_server_extension = proxy.setup load_jupyter_server_extension = proxy.setup
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json import json
from pathlib import Path from pathlib import Path
import shutil import shutil
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json import json
from pathlib import Path from pathlib import Path
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib import importlib
import json import json
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
from typing import Any
from .common.serializer import dump from .common.serializer import dump
from .runtime.env_vars import trial_env_vars from .runtime.env_vars import trial_env_vars
from .runtime import platform from .runtime import platform
from .typehint import Parameters, TrialMetric
__all__ = [ __all__ = [
'get_next_parameter', 'get_next_parameter',
'get_next_parameters',
'get_current_parameter', 'get_current_parameter',
'report_intermediate_result', 'report_intermediate_result',
'report_final_result', 'report_final_result',
...@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id() ...@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id() _sequence_id = platform.get_sequence_id()
def get_next_parameter(): def get_next_parameter() -> Parameters:
""" """
Get the hyperparameters generated by tuner. Get the hyperparameters generated by tuner.
...@@ -32,7 +37,7 @@ def get_next_parameter(): ...@@ -32,7 +37,7 @@ def get_next_parameter():
Examples Examples
-------- --------
Assuming the search space is: Assuming the :doc:`search space </hpo/search_space>` is:
.. code-block:: .. code-block::
...@@ -52,16 +57,22 @@ def get_next_parameter(): ...@@ -52,16 +57,22 @@ def get_next_parameter():
Returns Returns
------- -------
dict :class:`~nni.typehint.Parameters`
A hyperparameter set sampled from search space. A hyperparameter set sampled from search space.
""" """
global _params global _params
_params = platform.get_next_parameter() _params = platform.get_next_parameter()
if _params is None: if _params is None:
return None return None # type: ignore
return _params['parameters'] return _params['parameters']
def get_current_parameter(tag=None): def get_next_parameters() -> Parameters:
"""
Alias of :func:`get_next_parameter`
"""
return get_next_parameter()
def get_current_parameter(tag: str | None = None) -> Any:
global _params global _params
if _params is None: if _params is None:
return None return None
...@@ -94,13 +105,13 @@ def get_sequence_id() -> int: ...@@ -94,13 +105,13 @@ def get_sequence_id() -> int:
_intermediate_seq = 0 _intermediate_seq = 0
def overwrite_intermediate_seq(value): def overwrite_intermediate_seq(value: int) -> None:
assert isinstance(value, int) assert isinstance(value, int)
global _intermediate_seq global _intermediate_seq
_intermediate_seq = value _intermediate_seq = value
def report_intermediate_result(metric): def report_intermediate_result(metric: TrialMetric | dict[str, Any]) -> None:
""" """
Reports intermediate result to NNI. Reports intermediate result to NNI.
...@@ -110,11 +121,16 @@ def report_intermediate_result(metric): ...@@ -110,11 +121,16 @@ def report_intermediate_result(metric):
and other items can be visualized with web portal. and other items can be visualized with web portal.
Typically ``metric`` is per-epoch accuracy or loss. Typically ``metric`` is per-epoch accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The intermeidate result.
""" """
global _intermediate_seq global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \ assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result' 'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = dump({ dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None, 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
...@@ -122,9 +138,9 @@ def report_intermediate_result(metric): ...@@ -122,9 +138,9 @@ def report_intermediate_result(metric):
'value': dump(metric) 'value': dump(metric)
}) })
_intermediate_seq += 1 _intermediate_seq += 1
platform.send_metric(metric) platform.send_metric(dumped_metric)
def report_final_result(metric): def report_final_result(metric: TrialMetric | dict[str, Any]) -> None:
""" """
Reports final result to NNI. Reports final result to NNI.
...@@ -134,14 +150,19 @@ def report_final_result(metric): ...@@ -134,14 +150,19 @@ def report_final_result(metric):
and other items can be visualized with web portal. and other items can be visualized with web portal.
Typically ``metric`` is the final accuracy or loss. Typically ``metric`` is the final accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The final result.
""" """
assert _params or trial_env_vars.NNI_PLATFORM is None, \ assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result' 'nni.get_next_parameter() needs to be called before report_final_result'
metric = dump({ dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None, 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, 'sequence': 0,
'value': dump(metric) 'value': dump(metric)
}) })
platform.send_metric(metric) platform.send_metric(dumped_metric)
...@@ -8,11 +8,14 @@ A new trial will run with this configuration. ...@@ -8,11 +8,14 @@ A new trial will run with this configuration.
See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details. See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
""" """
from __future__ import annotations
import logging import logging
import nni import nni
from .recoverable import Recoverable from .recoverable import Recoverable
from .typehint import Parameters, SearchSpace, TrialMetric, TrialRecord
__all__ = ['Tuner'] __all__ = ['Tuner']
...@@ -67,7 +70,7 @@ class Tuner(Recoverable): ...@@ -67,7 +70,7 @@ class Tuner(Recoverable):
:class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner` :class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
""" """
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id: int, **kwargs) -> Parameters:
""" """
Abstract method which provides a set of hyper-parameters. Abstract method which provides a set of hyper-parameters.
...@@ -100,7 +103,7 @@ class Tuner(Recoverable): ...@@ -100,7 +103,7 @@ class Tuner(Recoverable):
# we need to design a new exception for this purpose # we need to design a new exception for this purpose
raise NotImplementedError('Tuner: generate_parameters not implemented') raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list, **kwargs): def generate_multiple_parameters(self, parameter_id_list: list[int], **kwargs) -> list[Parameters]:
""" """
Callback method which provides multiple sets of hyper-parameters. Callback method which provides multiple sets of hyper-parameters.
...@@ -135,7 +138,7 @@ class Tuner(Recoverable): ...@@ -135,7 +138,7 @@ class Tuner(Recoverable):
result.append(res) result.append(res)
return result return result
def receive_trial_result(self, parameter_id, parameters, value, **kwargs): def receive_trial_result(self, parameter_id: int, parameters: Parameters, value: TrialMetric, **kwargs) -> None:
""" """
Abstract method invoked when a trial reports its final result. Must override. Abstract method invoked when a trial reports its final result. Must override.
...@@ -165,7 +168,7 @@ class Tuner(Recoverable): ...@@ -165,7 +168,7 @@ class Tuner(Recoverable):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
self._accept_customized = accept self._accept_customized = accept
def trial_end(self, parameter_id, success, **kwargs): def trial_end(self, parameter_id: int, success: bool, **kwargs) -> None:
""" """
Abstract method invoked when a trial is completed or terminated. Do nothing by default. Abstract method invoked when a trial is completed or terminated. Do nothing by default.
...@@ -179,7 +182,7 @@ class Tuner(Recoverable): ...@@ -179,7 +182,7 @@ class Tuner(Recoverable):
Unstable parameters which should be ignored by normal users. Unstable parameters which should be ignored by normal users.
""" """
def update_search_space(self, search_space): def update_search_space(self, search_space: SearchSpace) -> None:
""" """
Abstract method for updating the search space. Must override. Abstract method for updating the search space. Must override.
...@@ -194,21 +197,21 @@ class Tuner(Recoverable): ...@@ -194,21 +197,21 @@ class Tuner(Recoverable):
""" """
raise NotImplementedError('Tuner: update_search_space not implemented') raise NotImplementedError('Tuner: update_search_space not implemented')
def load_checkpoint(self): def load_checkpoint(self) -> None:
""" """
Internal API under revising, not recommended for end users. Internal API under revising, not recommended for end users.
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path) _logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self): def save_checkpoint(self) -> None:
""" """
Internal API under revising, not recommended for end users. Internal API under revising, not recommended for end users.
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path) _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def import_data(self, data): def import_data(self, data: list[TrialRecord]) -> None:
""" """
Internal API under revising, not recommended for end users. Internal API under revising, not recommended for end users.
""" """
...@@ -216,8 +219,8 @@ class Tuner(Recoverable): ...@@ -216,8 +219,8 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' # data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass pass
def _on_exit(self): def _on_exit(self) -> None:
pass pass
def _on_error(self): def _on_error(self) -> None:
pass pass
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
"""
Types for static checking.
"""
__all__ = [
'Literal',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
]
import sys import sys
import typing from typing import Any, Dict, List, TYPE_CHECKING
if typing.TYPE_CHECKING or sys.version_info >= (3, 8): if TYPE_CHECKING or sys.version_info >= (3, 8):
Literal = typing.Literal from typing import Literal, TypedDict
else: else:
Literal = typing.Any from typing_extensions import Literal, TypedDict
Parameters = Dict[str, Any]
"""
Return type of :func:`nni.get_next_parameter`.
For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space </hpo/search_space>`.
Customized tuners do not need to follow the constraint and can use anything serializable.
"""
class _ParameterSearchSpace(TypedDict):
_type: Literal[
'choice', 'randint',
'uniform', 'loguniform', 'quniform', 'qloguniform',
'normal', 'lognormal', 'qnormal', 'qlognormal',
]
_value: List[Any]
SearchSpace = Dict[str, _ParameterSearchSpace]
"""
Type of ``experiment.config.search_space``.
For built-in tuners, the format is detailed in :doc:`/hpo/search_space`.
Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``.
"""
TrialMetric = float
"""
Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`.
For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc).
Customized tuners do not need to follow this constraint and can use anything serializable.
"""
class TrialRecord(TypedDict):
parameter: Parameters
value: TrialMetric
...@@ -63,6 +63,9 @@ stages: ...@@ -63,6 +63,9 @@ stages:
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
displayName: flake8 displayName: flake8
- script: |
python -m pyright nni
- job: typescript - job: typescript
pool: pool:
vmImage: ubuntu-latest vmImage: ubuntu-latest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment