"src/vscode:/vscode.git/clone" did not exist on "5e5c27a63b1637556a17e17546147da6cb6d732e"
Unverified Commit 8735fa58 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Update trial doc string (#1713)

* trial docstring
parent f5803f68
...@@ -43,8 +43,18 @@ _sequence_id = platform.get_sequence_id() ...@@ -43,8 +43,18 @@ _sequence_id = platform.get_sequence_id()
def get_next_parameter(): def get_next_parameter():
"""Returns a set of (hyper-)paremeters generated by Tuner. """
Returns None if no more (hyper-)parameters can be generated by Tuner.""" Get the hyper paremeters generated by tuner. For a multiphase experiment, it returns a new group of hyper
parameters at each call of get_next_parameter. For a non-multiphase (multiPhase is not configured or set to False)
experiment, it returns hyper parameters only on the first call for each trial job, it returns None since second call.
This API should be called only once in each trial job of an experiment which is not specified as multiphase.
Returns
-------
dict
A dict object contains the hyper parameters generated by tuner, the keys of the dict are defined in
search space. Returns None if no more hyper parameters can be generated by tuner.
"""
global _params global _params
_params = platform.get_next_parameter() _params = platform.get_next_parameter()
if _params is None: if _params is None:
...@@ -52,6 +62,15 @@ def get_next_parameter(): ...@@ -52,6 +62,15 @@ def get_next_parameter():
return _params['parameters'] return _params['parameters']
def get_current_parameter(tag=None): def get_current_parameter(tag=None):
"""
Get current hyper parameters generated by tuner. It returns the same group of hyper parameters as the last
call of get_next_parameter returns.
Parameters
----------
tag: str
hyper parameter key
"""
global _params global _params
if _params is None: if _params is None:
return None return None
...@@ -60,19 +79,51 @@ def get_current_parameter(tag=None): ...@@ -60,19 +79,51 @@ def get_current_parameter(tag=None):
return _params['parameters'][tag] return _params['parameters'][tag]
def get_experiment_id(): def get_experiment_id():
"""
Get experiment ID.
Returns
-------
str
Identifier of current experiment
"""
return _experiment_id return _experiment_id
def get_trial_id(): def get_trial_id():
"""
Get trial job ID which is string identifier of a trial job, for example 'MoXrp'. In one experiment, each trial
job has an unique string ID.
Returns
-------
str
Identifier of current trial job which is calling this API.
"""
return _trial_id return _trial_id
def get_sequence_id(): def get_sequence_id():
"""
Get trial job sequence nubmer. A sequence number is an integer value assigned to each trial job base on the
order they are submitted, incremental starting from 0. In one experiment, both trial job ID and sequence number
are unique for each trial job, they are of different data types.
Returns
-------
int
Sequence number of current trial job which is calling this API.
"""
return _sequence_id return _sequence_id
_intermediate_seq = 0 _intermediate_seq = 0
def report_intermediate_result(metric): def report_intermediate_result(metric):
"""Reports intermediate result to Assessor. """
metric: serializable object. Reports intermediate result to NNI.
Parameters
----------
metric:
serializable object.
""" """
global _intermediate_seq global _intermediate_seq
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result' assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result'
...@@ -88,8 +139,13 @@ def report_intermediate_result(metric): ...@@ -88,8 +139,13 @@ def report_intermediate_result(metric):
def report_final_result(metric): def report_final_result(metric):
"""Reports final result to tuner. """
metric: serializable object. Reports final result to NNI.
Parameters
----------
metric:
serializable object.
""" """
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result' assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result'
metric = json_tricks.dumps({ metric = json_tricks.dumps({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment