Unverified Commit 5136a86d authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Typehint and copyright header (#4669)

parent 68347c5e
......@@ -37,6 +37,8 @@ def to_v2(v1):
_move_field(v1_trial, v2, 'command', 'trialCommand')
_move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory')
_move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber')
else:
v1_trial = {}
for algo_type in ['tuner', 'assessor', 'advisor']:
v1_algo = v1.pop(algo_type, None)
......
......@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
if self.password is not None:
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
elif not Path(self.ssh_key_file).is_file():
elif not Path(self.ssh_key_file).is_file(): # type: ignore
raise ValueError(
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
import json
from typing import List
......
......@@ -10,7 +10,7 @@ from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Any
from typing import Any
import colorama
import psutil
......@@ -34,8 +34,7 @@ class RunMode(Enum):
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
NOTE:
This API is non-stable and is likely to get refactored in next release.
NOTE: This API is non-stable and is likely to get refactored in upcoming release.
"""
# TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
......@@ -72,15 +71,15 @@ class Experiment:
Web portal port. Or ``None`` if the experiment is not running.
"""
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None:
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None):
nni.runtime.log.init_logger_for_command_line()
self.config: Optional[ExperimentConfig] = None
self.config: ExperimentConfig | None = None
self.id: str = management.generate_experiment_id()
self.port: Optional[int] = None
self._proc: Optional[Popen] = None
self.action = 'create'
self.url_prefix: Optional[str] = None
self.port: int | None = None
self._proc: Popen | psutil.Process | None = None
self._action = 'create'
self.url_prefix: str | None = None
if isinstance(config_or_platform, (str, list)):
self.config = ExperimentConfig(config_or_platform)
......@@ -101,6 +100,7 @@ class Experiment:
debug
Whether to start in debug mode.
"""
assert self.config is not None
if run_mode is not RunMode.Detach:
atexit.register(self.stop)
......@@ -114,7 +114,7 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc = launcher.start_experiment(self.action, self.id, config, port, debug, run_mode, self.url_prefix)
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
assert self._proc is not None
self.port = port # port will be None if start up failed
......@@ -144,16 +144,16 @@ class Experiment:
_logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid)
self.id = None
self.id = None # type: ignore
self.port = None
self._proc = None
_logger.info('Experiment stopped')
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool:
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
"""
Run the experiment.
If ``wait_completion`` is True, this function will block until experiment finish or error.
If ``wait_completion`` is ``True``, this function will block until experiment finish or error.
Return ``True`` when experiment done; or return ``False`` when experiment failed.
......@@ -247,7 +247,7 @@ class Experiment:
def _resume(exp_id, exp_dir=None):
exp = Experiment(None)
exp.id = exp_id
exp.action = 'resume'
exp._action = 'resume'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp
......@@ -255,7 +255,7 @@ class Experiment:
def _view(exp_id, exp_dir=None):
exp = Experiment(None)
exp.id = exp_id
exp.action = 'view'
exp._action = 'view'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import contextlib
from dataclasses import dataclass, fields
from datetime import datetime
......@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return proc
def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
def _start_rest_server(nni_manager_args, run_mode) -> Popen:
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
......@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
from subprocess import CREATE_NEW_PROCESS_GROUP
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore
def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
def start_experiment_retiarii(exp_id, config, port, debug):
pipe = None
proc = None
......@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
args['dispatcher_pipe'] = pipe_path
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
......@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def get_stopped_experiment_config(exp_id, exp_dir=None):
config_json = get_stopped_experiment_config_json(exp_id, exp_dir)
config = ExperimentConfig(**config_json)
config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
config = ExperimentConfig(**config_json) # type: ignore
if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger.warning(msg, exp_dir, config.experiment_working_directory)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import random
import string
......
from io import BufferedIOBase
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys
......@@ -25,7 +27,7 @@ if sys.platform == 'win32':
_winapi.NULL
)
def connect(self) -> BufferedIOBase:
def connect(self):
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b')
......@@ -55,7 +57,7 @@ else:
self._socket.bind(self.path)
self._socket.listen(1) # only accepts one connection
def connect(self) -> BufferedIOBase:
def connect(self):
conn, _ = self._socket.accept()
self.file = conn.makefile('rwb')
return self.file
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Optional
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import os
class Recoverable:
def load_checkpoint(self):
def load_checkpoint(self) -> None:
pass
def save_checkpoint(self):
def save_checkpoint(self) -> None:
pass
def get_checkpoint_path(self):
def get_checkpoint_path(self) -> str | None:
ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path
......
......@@ -14,7 +14,7 @@ def get_config_directory() -> Path:
Create it if not exist.
"""
if os.getenv('NNI_CONFIG_DIR') is not None:
config_dir = Path(os.getenv('NNI_CONFIG_DIR'))
config_dir = Path(os.getenv('NNI_CONFIG_DIR')) # type: ignore
elif sys.prefix != sys.base_prefix or Path(sys.prefix, 'conda-meta').is_dir():
config_dir = Path(sys.prefix, 'nni')
elif sys.platform == 'win32':
......@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
"""
Get a readonly builtin config file.
"""
return Path(nni.__path__[0], 'runtime/default_config', name)
return Path(nni.__path__[0], 'runtime/default_config', name) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import sys
from datetime import datetime
......@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
_register_handler(StreamHandler(sys.stdout), logging.INFO)
def _prepare_log_dir(path: Optional[str]) -> Path:
def _prepare_log_dir(path: Path | str) -> Path:
if path is None:
return Path()
ret = Path(path)
......@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
def __init__(self, log_file: TextIOBase):
self.file: TextIOBase = log_file
self.line_buffer: Optional[str] = None
self.line_start_time: Optional[datetime] = None
self.line_start_time: datetime = datetime.fromtimestamp(0)
def write(self, s: str) -> int:
cur_time = datetime.now()
......
......@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
except Exception as e:
_logger.error('Assessor error')
_logger.exception(e)
raise
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from . import proxy
load_jupyter_server_extension = proxy.setup
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from pathlib import Path
import shutil
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from pathlib import Path
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import json
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any
from .common.serializer import dump
from .runtime.env_vars import trial_env_vars
from .runtime import platform
from .typehint import Parameters, TrialMetric
__all__ = [
'get_next_parameter',
'get_next_parameters',
'get_current_parameter',
'report_intermediate_result',
'report_final_result',
......@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id()
def get_next_parameter():
def get_next_parameter() -> Parameters:
"""
Get the hyperparameters generated by tuner.
......@@ -32,7 +37,7 @@ def get_next_parameter():
Examples
--------
Assuming the search space is:
Assuming the :doc:`search space </hpo/search_space>` is:
.. code-block::
......@@ -52,16 +57,22 @@ def get_next_parameter():
Returns
-------
dict
:class:`~nni.typehint.Parameters`
A hyperparameter set sampled from search space.
"""
global _params
_params = platform.get_next_parameter()
if _params is None:
return None
return None # type: ignore
return _params['parameters']
def get_current_parameter(tag=None):
def get_next_parameters() -> Parameters:
"""
Alias of :func:`get_next_parameter`
"""
return get_next_parameter()
def get_current_parameter(tag: str | None = None) -> Any:
global _params
if _params is None:
return None
......@@ -94,13 +105,13 @@ def get_sequence_id() -> int:
_intermediate_seq = 0
def overwrite_intermediate_seq(value):
def overwrite_intermediate_seq(value: int) -> None:
assert isinstance(value, int)
global _intermediate_seq
_intermediate_seq = value
def report_intermediate_result(metric):
def report_intermediate_result(metric: TrialMetric | dict[str, Any]) -> None:
"""
Reports intermediate result to NNI.
......@@ -110,11 +121,16 @@ def report_intermediate_result(metric):
and other items can be visualized with web portal.
Typically ``metric`` is per-epoch accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The intermeidate result.
"""
global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = dump({
dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
......@@ -122,9 +138,9 @@ def report_intermediate_result(metric):
'value': dump(metric)
})
_intermediate_seq += 1
platform.send_metric(metric)
platform.send_metric(dumped_metric)
def report_final_result(metric):
def report_final_result(metric: TrialMetric | dict[str, Any]) -> None:
"""
Reports final result to NNI.
......@@ -134,14 +150,19 @@ def report_final_result(metric):
and other items can be visualized with web portal.
Typically ``metric`` is the final accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The final result.
"""
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
metric = dump({
dumped_metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0,
'value': dump(metric)
})
platform.send_metric(metric)
platform.send_metric(dumped_metric)
......@@ -8,11 +8,14 @@ A new trial will run with this configuration.
See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
"""
from __future__ import annotations
import logging
import nni
from .recoverable import Recoverable
from .typehint import Parameters, SearchSpace, TrialMetric, TrialRecord
__all__ = ['Tuner']
......@@ -67,7 +70,7 @@ class Tuner(Recoverable):
:class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
"""
def generate_parameters(self, parameter_id, **kwargs):
def generate_parameters(self, parameter_id: int, **kwargs) -> Parameters:
"""
Abstract method which provides a set of hyper-parameters.
......@@ -100,7 +103,7 @@ class Tuner(Recoverable):
# we need to design a new exception for this purpose
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
def generate_multiple_parameters(self, parameter_id_list: list[int], **kwargs) -> list[Parameters]:
"""
Callback method which provides multiple sets of hyper-parameters.
......@@ -135,7 +138,7 @@ class Tuner(Recoverable):
result.append(res)
return result
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
def receive_trial_result(self, parameter_id: int, parameters: Parameters, value: TrialMetric, **kwargs) -> None:
"""
Abstract method invoked when a trial reports its final result. Must override.
......@@ -165,7 +168,7 @@ class Tuner(Recoverable):
# pylint: disable=attribute-defined-outside-init
self._accept_customized = accept
def trial_end(self, parameter_id, success, **kwargs):
def trial_end(self, parameter_id: int, success: bool, **kwargs) -> None:
"""
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
......@@ -179,7 +182,7 @@ class Tuner(Recoverable):
Unstable parameters which should be ignored by normal users.
"""
def update_search_space(self, search_space):
def update_search_space(self, search_space: SearchSpace) -> None:
"""
Abstract method for updating the search space. Must override.
......@@ -194,21 +197,21 @@ class Tuner(Recoverable):
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def load_checkpoint(self):
def load_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self):
def save_checkpoint(self) -> None:
"""
Internal API under revising, not recommended for end users.
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def import_data(self, data):
def import_data(self, data: list[TrialRecord]) -> None:
"""
Internal API under revising, not recommended for end users.
"""
......@@ -216,8 +219,8 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass
def _on_exit(self):
def _on_exit(self) -> None:
pass
def _on_error(self):
def _on_error(self) -> None:
pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Types for static checking.
"""
__all__ = [
'Literal',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
]
import sys
import typing
from typing import Any, Dict, List, TYPE_CHECKING
if typing.TYPE_CHECKING or sys.version_info >= (3, 8):
Literal = typing.Literal
if TYPE_CHECKING or sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
Literal = typing.Any
from typing_extensions import Literal, TypedDict
Parameters = Dict[str, Any]
"""
Return type of :func:`nni.get_next_parameter`.
For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space </hpo/search_space>`.
Customized tuners do not need to follow the constraint and can use anything serializable.
"""
class _ParameterSearchSpace(TypedDict):
_type: Literal[
'choice', 'randint',
'uniform', 'loguniform', 'quniform', 'qloguniform',
'normal', 'lognormal', 'qnormal', 'qlognormal',
]
_value: List[Any]
SearchSpace = Dict[str, _ParameterSearchSpace]
"""
Type of ``experiment.config.search_space``.
For built-in tuners, the format is detailed in :doc:`/hpo/search_space`.
Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``.
"""
TrialMetric = float
"""
Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`.
For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc).
Customized tuners do not need to follow this constraint and can use anything serializable.
"""
class TrialRecord(TypedDict):
parameter: Parameters
value: TrialMetric
......@@ -63,6 +63,9 @@ stages:
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
displayName: flake8
- script: |
python -m pyright nni
- job: typescript
pool:
vmImage: ubuntu-latest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment