Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
8735fa58
"src/vscode:/vscode.git/clone" did not exist on "5e5c27a63b1637556a17e17546147da6cb6d732e"
Unverified
Commit
8735fa58
authored
Nov 11, 2019
by
chicm-ms
Committed by
GitHub
Nov 11, 2019
Browse files
Update trial doc string (#1713)
* trial docstring
parent
f5803f68
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
6 deletions
+62
-6
src/sdk/pynni/nni/trial.py
src/sdk/pynni/nni/trial.py
+62
-6
No files found.
src/sdk/pynni/nni/trial.py
View file @
8735fa58
...
@@ -43,8 +43,18 @@ _sequence_id = platform.get_sequence_id()
...
@@ -43,8 +43,18 @@ _sequence_id = platform.get_sequence_id()
def
get_next_parameter
():
def
get_next_parameter
():
"""Returns a set of (hyper-)paremeters generated by Tuner.
"""
Returns None if no more (hyper-)parameters can be generated by Tuner."""
Get the hyper paremeters generated by tuner. For a multiphase experiment, it returns a new group of hyper
parameters at each call of get_next_parameter. For a non-multiphase (multiPhase is not configured or set to False)
experiment, it returns hyper parameters only on the first call for each trial job, it returns None since second call.
This API should be called only once in each trial job of an experiment which is not specified as multiphase.
Returns
-------
dict
A dict object contains the hyper parameters generated by tuner, the keys of the dict are defined in
search space. Returns None if no more hyper parameters can be generated by tuner.
"""
global
_params
global
_params
_params
=
platform
.
get_next_parameter
()
_params
=
platform
.
get_next_parameter
()
if
_params
is
None
:
if
_params
is
None
:
...
@@ -52,6 +62,15 @@ def get_next_parameter():
...
@@ -52,6 +62,15 @@ def get_next_parameter():
return
_params
[
'parameters'
]
return
_params
[
'parameters'
]
def
get_current_parameter
(
tag
=
None
):
def
get_current_parameter
(
tag
=
None
):
"""
Get current hyper parameters generated by tuner. It returns the same group of hyper parameters as the last
call of get_next_parameter returns.
Parameters
----------
tag: str
hyper parameter key
"""
global
_params
global
_params
if
_params
is
None
:
if
_params
is
None
:
return
None
return
None
...
@@ -60,19 +79,51 @@ def get_current_parameter(tag=None):
...
@@ -60,19 +79,51 @@ def get_current_parameter(tag=None):
return
_params
[
'parameters'
][
tag
]
return
_params
[
'parameters'
][
tag
]
def
get_experiment_id
():
def
get_experiment_id
():
"""
Get experiment ID.
Returns
-------
str
Identifier of current experiment
"""
return
_experiment_id
return
_experiment_id
def
get_trial_id
():
def
get_trial_id
():
"""
Get trial job ID which is string identifier of a trial job, for example 'MoXrp'. In one experiment, each trial
job has an unique string ID.
Returns
-------
str
Identifier of current trial job which is calling this API.
"""
return
_trial_id
return
_trial_id
def
get_sequence_id
():
def
get_sequence_id
():
"""
Get trial job sequence nubmer. A sequence number is an integer value assigned to each trial job base on the
order they are submitted, incremental starting from 0. In one experiment, both trial job ID and sequence number
are unique for each trial job, they are of different data types.
Returns
-------
int
Sequence number of current trial job which is calling this API.
"""
return
_sequence_id
return
_sequence_id
_intermediate_seq
=
0
_intermediate_seq
=
0
def
report_intermediate_result
(
metric
):
def
report_intermediate_result
(
metric
):
"""Reports intermediate result to Assessor.
"""
metric: serializable object.
Reports intermediate result to NNI.
Parameters
----------
metric:
serializable object.
"""
"""
global
_intermediate_seq
global
_intermediate_seq
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_intermediate_result'
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_intermediate_result'
...
@@ -88,8 +139,13 @@ def report_intermediate_result(metric):
...
@@ -88,8 +139,13 @@ def report_intermediate_result(metric):
def
report_final_result
(
metric
):
def
report_final_result
(
metric
):
"""Reports final result to tuner.
"""
metric: serializable object.
Reports final result to NNI.
Parameters
----------
metric:
serializable object.
"""
"""
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_final_result'
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_final_result'
metric
=
json_tricks
.
dumps
({
metric
=
json_tricks
.
dumps
({
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment