Unverified Commit 553e91f4 authored by liuzhe-lz, committed by GitHub

Update trial and experiment docstr (#4672)

parent 22165cea
......@@ -4,7 +4,6 @@ HPO API Reference
Trial APIs
----------
.. autofunction:: nni.get_current_parameter
.. autofunction:: nni.get_experiment_id
.. autofunction:: nni.get_next_parameter
.. autofunction:: nni.get_sequence_id
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import inspect
from pathlib import Path, PurePath
from typing import overload, Union, List
from nni.experiment import Experiment, ExperimentConfig
from nni.algorithms.compression.pytorch.auto_compress.interface import AbstractAutoCompressionModule
......@@ -11,49 +12,19 @@ from nni.algorithms.compression.pytorch.auto_compress.interface import AbstractA
class AutoCompressionExperiment(Experiment):
@overload
def __init__(self, auto_compress_module: AbstractAutoCompressionModule, config: ExperimentConfig) -> None:
"""
Prepare an experiment.
Use `Experiment.run()` to launch it.
Parameters
----------
auto_compress_module
The module provided by the user, which implements the `AbstractAutoCompressionModule` interface.
Remember to put the module file under `trial_code_directory`.
config
Experiment configuration.
"""
...
@overload
def __init__(self, auto_compress_module: AbstractAutoCompressionModule, training_service: Union[str, List[str]]) -> None:
def __init__(self, auto_compress_module: AbstractAutoCompressionModule, config_or_platform: ExperimentConfig | str | list[str]) -> None:
"""
Prepare an experiment, leaving configuration fields to be set later.
Example usage::
experiment = Experiment(auto_compress_module, 'remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.run(8080)
Prepare an auto compression experiment.
Parameters
----------
auto_compress_module
The module provided by the user, which implements the `AbstractAutoCompressionModule` interface.
Remember to put the module file under `trial_code_directory`.
training_service
Name of training service.
Supported values: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl", and hybrid training service.
config_or_platform
Experiment configuration or training service name.
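Example usage (a minimal sketch; ``MyAutoCompressionModule`` is a hypothetical user-defined implementation of `AbstractAutoCompressionModule`)::
    experiment = AutoCompressionExperiment(MyAutoCompressionModule, 'local')
    experiment.config.trial_code_directory = '.'  # directory that contains the module file
    experiment.run(8080)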
"""
...
def __init__(self, auto_compress_module: AbstractAutoCompressionModule, config=None, training_service=None):
super().__init__(config, training_service)
super().__init__(config_or_platform)
self.module_file_path = str(PurePath(inspect.getfile(auto_compress_module)))
self.module_name = auto_compress_module.__name__
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import atexit
from enum import Enum
import logging
......@@ -5,7 +10,7 @@ from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any
from typing import Optional, Any
import colorama
import psutil
......@@ -25,66 +30,49 @@ class RunMode(Enum):
"""
Configure lifecycle and output redirection of the NNI manager process.
- Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
- Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
NOTE:
This API is non-stable and is likely to be refactored in the next release.
NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
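Example (a sketch; assumes ``experiment`` is an already-configured :class:`Experiment`)::
    experiment.start(port=8080, run_mode=RunMode.Foreground)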
"""
# TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
Background = 'background'
Foreground = 'foreground'
Detach = 'detach'
class Experiment:
"""
Create and stop an NNI experiment.
Manage an NNI experiment.
You can specify either an :class:`ExperimentConfig` object or a training service name.
If a training service name is given, a blank config template for that service will be generated.
When the configuration is complete, use :meth:`Experiment.run` to launch the experiment.
Example
-------
.. code-block::
experiment = Experiment('remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.run(8080)
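The config-object form works too (a sketch; the field values are placeholders):
.. code-block::
    config = ExperimentConfig('local')
    config.trial_command = 'python3 trial.py'
    experiment = Experiment(config)
    experiment.run(8080)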
Attributes
----------
config
Experiment configuration.
id
Experiment ID.
port
Web UI port of the experiment, or `None` if it is not running.
Web portal port, or ``None`` if the experiment is not running.
"""
@overload
def __init__(self, config: ExperimentConfig) -> None:
"""
Prepare an experiment.
Use `Experiment.run()` to launch it.
Parameters
----------
config
Experiment configuration.
"""
...
@overload
def __init__(self, training_service: Union[str, List[str]]) -> None:
"""
Prepare an experiment, leaving configuration fields to be set later.
Example usage::
experiment = Experiment('remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.run(8080)
Parameters
----------
training_service
Name of training service.
Supported values: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl", and hybrid training service.
"""
...
def __init__(self, config=None, training_service=None):
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None) -> None:
nni.runtime.log.init_logger_for_command_line()
self.config: Optional[ExperimentConfig] = None
......@@ -94,11 +82,10 @@ class Experiment:
self.mode = 'new'
self.url_prefix: Optional[str] = None
args = [config, training_service] # deal with overloading
if isinstance(args[0], (str, list)):
self.config = ExperimentConfig(args[0])
if isinstance(config_or_platform, (str, list)):
self.config = ExperimentConfig(config_or_platform)
else:
self.config = args[0]
self.config = config_or_platform
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
......@@ -143,7 +130,7 @@ class Experiment:
def stop(self) -> None:
"""
Stop background experiment.
Stop the experiment.
"""
_logger.info('Stopping experiment, please wait...')
atexit.unregister(self.stop)
......@@ -166,11 +153,11 @@ class Experiment:
"""
Run the experiment.
If wait_completion is True, this function will block until experiment finish or error.
If ``wait_completion`` is ``True``, this function blocks until the experiment finishes or fails.
Return `True` when experiment done; or return `False` when experiment failed.
It returns ``True`` if the experiment completes successfully, or ``False`` if it fails.
Else if wait_completion is False, this function will non-block and return None immediately.
If ``wait_completion`` is ``False``, this function returns ``None`` immediately without blocking.
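Example (a sketch; assumes a configured experiment)::
    finished = experiment.run(port=8080)  # blocks until the experiment ends
    print('succeeded' if finished else 'failed')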
"""
self.start(port, debug)
if wait_completion:
......@@ -196,7 +183,7 @@ class Experiment:
port
The port of web UI.
"""
experiment = Experiment()
experiment = Experiment(None)
experiment.port = port
experiment.id = experiment.get_experiment_profile().get('id')
status = experiment.get_status()
......@@ -258,7 +245,7 @@ class Experiment:
@staticmethod
def _resume(exp_id, exp_dir=None):
exp = Experiment()
exp = Experiment(None)
exp.id = exp_id
exp.mode = 'resume'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
......@@ -266,7 +253,7 @@ class Experiment:
@staticmethod
def _view(exp_id, exp_dir=None):
exp = Experiment()
exp = Experiment(None)
exp.id = exp_id
exp.mode = 'view'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
......
......@@ -25,16 +25,35 @@ _sequence_id = platform.get_sequence_id()
def get_next_parameter():
"""
Get the hyper parameters generated by tuner. For a multiphase experiment, it returns a new group of hyper
parameters at each call of get_next_parameter. For a non-multiphase (multiPhase is not configured or set to False)
experiment, it returns hyper parameters only on the first call for each trial job, and returns None from the second call on.
This API should be called only once in each trial job of an experiment which is not specified as multiphase.
Get the hyperparameters generated by the tuner.
Each trial should invoke this function exactly once;
otherwise the behavior is undefined.
Examples
--------
Assuming the search space is:
.. code-block::
{
'activation': {'_type': 'choice', '_value': ['relu', 'tanh', 'sigmoid']},
'learning_rate': {'_type': 'loguniform', '_value': [0.0001, 0.1]}
}
Then this function might return:
.. code-block::
{
'activation': 'relu',
'learning_rate': 0.02
}
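In a trial script, the returned values are typically consumed like this (a sketch; ``build_model`` is a hypothetical helper)::
    params = nni.get_next_parameter()
    model = build_model(activation=params['activation'],
                        learning_rate=params['learning_rate'])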
Returns
-------
dict
A dict object containing the hyper parameters generated by tuner; the keys of the dict are defined in the
search space. Returns None if no more hyper parameters can be generated by tuner.
A hyperparameter set sampled from the search space.
"""
global _params
_params = platform.get_next_parameter()
......@@ -43,15 +62,6 @@ def get_next_parameter():
return _params['parameters']
def get_current_parameter(tag=None):
"""
Get current hyper parameters generated by tuner. It returns the same group of hyper parameters that the last
call of get_next_parameter returned.
Parameters
----------
tag: str
hyper parameter key
"""
global _params
if _params is None:
return None
......@@ -59,39 +69,25 @@ def get_current_parameter(tag=None):
return _params['parameters']
return _params['parameters'][tag]
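# Example (hypothetical values): if get_next_parameter() returned
# {'learning_rate': 0.02}, get_current_parameter() returns that whole dict
# and get_current_parameter('learning_rate') returns 0.02.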
def get_experiment_id():
def get_experiment_id() -> str:
"""
Get experiment ID.
Returns
-------
str
Identifier of current experiment
Return experiment ID.
"""
return _experiment_id
def get_trial_id():
def get_trial_id() -> str:
"""
Get trial job ID, a string identifier of a trial job, for example 'MoXrp'. In one experiment, each trial
job has a unique string ID.
Return the unique ID of the trial that is currently running.
Returns
-------
str
Identifier of current trial job which is calling this API.
This is shown as "ID" in the web portal's trial table.
"""
return _trial_id
def get_sequence_id():
def get_sequence_id() -> int:
"""
Get trial job sequence number. A sequence number is an integer assigned to each trial job based on the
order in which they are submitted, incrementing from 0. In one experiment, both trial job ID and sequence number
are unique for each trial job, but they are of different data types.
Return the sequence number of the trial that is currently running.
Returns
-------
int
Sequence number of current trial job which is calling this API.
This is shown as "Trial No." in the web portal's trial table.
"""
return _sequence_id
......@@ -99,14 +95,6 @@ _intermediate_seq = 0
def overwrite_intermediate_seq(value):
"""
Overwrite intermediate sequence value.
Parameters
----------
value:
int
"""
assert isinstance(value, int)
global _intermediate_seq
_intermediate_seq = value
......@@ -116,10 +104,12 @@ def report_intermediate_result(metric):
"""
Reports intermediate result to NNI.
Parameters
----------
metric:
serializable object.
``metric`` should be either a float, or a dict whose ``metric['default']`` entry is a float.
If ``metric`` is a dict, ``metric['default']`` will be used by the tuner,
and the other items can be visualized in the web portal.
Typically ``metric`` is the per-epoch accuracy or loss.
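Example (a minimal sketch; ``train_epoch`` and ``evaluate`` are hypothetical helpers)::
    for epoch in range(max_epochs):
        train_epoch(model)
        nni.report_intermediate_result(evaluate(model))  # e.g. per-epoch accuracy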
"""
global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \
......@@ -138,11 +128,12 @@ def report_final_result(metric):
"""
Reports final result to NNI.
Parameters
----------
metric: serializable object
Usually (for built-in tuners to work), it should be a number, or
a dict with key "default" (a number), and any other extra keys.
``metric`` should be either a float, or a dict whose ``metric['default']`` entry is a float.
If ``metric`` is a dict, ``metric['default']`` will be used by the tuner,
and the other items can be visualized in the web portal.
Typically ``metric`` is the final accuracy or loss.
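Example (a sketch; ``final_accuracy`` and ``final_loss`` are placeholders)::
    nni.report_final_result({'default': final_accuracy, 'loss': final_loss})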
"""
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
......