Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
07655bf4
Commit
07655bf4
authored
Apr 10, 2019
by
Shufan Huang
Committed by
QuanluZhang
Apr 10, 2019
Browse files
Delete “return True“ that is no longer used during dispatcher call (#960)
remove return True
parent
1362de02
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
24 deletions
+5
-24
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+0
-11
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+4
-11
src/sdk/pynni/nni/msg_dispatcher_base.py
src/sdk/pynni/nni/msg_dispatcher_base.py
+1
-2
No files found.
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
07655bf4
...
@@ -333,7 +333,6 @@ class Hyperband(MsgDispatcherBase):
...
@@ -333,7 +333,6 @@ class Hyperband(MsgDispatcherBase):
"""
"""
self
.
handle_update_search_space
(
data
)
self
.
handle_update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
"""
"""
...
@@ -345,8 +344,6 @@ class Hyperband(MsgDispatcherBase):
...
@@ -345,8 +344,6 @@ class Hyperband(MsgDispatcherBase):
for
_
in
range
(
data
):
for
_
in
range
(
data
):
self
.
_request_one_trial_job
()
self
.
_request_one_trial_job
()
return
True
def
_request_one_trial_job
(
self
):
def
_request_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration."""
"""get one trial job, i.e., one hyperparameter configuration."""
if
not
self
.
generated_hyper_configs
:
if
not
self
.
generated_hyper_configs
:
...
@@ -372,8 +369,6 @@ class Hyperband(MsgDispatcherBase):
...
@@ -372,8 +369,6 @@ class Hyperband(MsgDispatcherBase):
}
}
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
return
True
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
"""data: JSON object, which is search space
"""data: JSON object, which is search space
...
@@ -385,8 +380,6 @@ class Hyperband(MsgDispatcherBase):
...
@@ -385,8 +380,6 @@ class Hyperband(MsgDispatcherBase):
self
.
searchspace_json
=
data
self
.
searchspace_json
=
data
self
.
random_state
=
np
.
random
.
RandomState
()
self
.
random_state
=
np
.
random
.
RandomState
()
return
True
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
"""
"""
Parameters
Parameters
...
@@ -415,8 +408,6 @@ class Hyperband(MsgDispatcherBase):
...
@@ -415,8 +408,6 @@ class Hyperband(MsgDispatcherBase):
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
self
.
credit
-=
1
self
.
credit
-=
1
return
True
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""
"""
Parameters
Parameters
...
@@ -442,7 +433,5 @@ class Hyperband(MsgDispatcherBase):
...
@@ -442,7 +433,5 @@ class Hyperband(MsgDispatcherBase):
else
:
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
return
True
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
pass
src/sdk/pynni/nni/msg_dispatcher.py
View file @
07655bf4
...
@@ -92,7 +92,6 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -92,7 +92,6 @@ class MsgDispatcher(MsgDispatcherBase):
"""
"""
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
# data: number or trial jobs
...
@@ -105,18 +104,15 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -105,18 +104,15 @@ class MsgDispatcher(MsgDispatcherBase):
# when parameters is None.
# when parameters is None.
if
len
(
params_list
)
<
len
(
ids
):
if
len
(
params_list
)
<
len
(
ids
):
send
(
CommandType
.
NoMoreTrialJobs
,
_pack_parameter
(
ids
[
0
],
''
))
send
(
CommandType
.
NoMoreTrialJobs
,
_pack_parameter
(
ids
[
0
],
''
))
return
True
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
# data: parameters
id_
=
_create_parameter_id
()
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
return
True
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""
"""
...
@@ -135,8 +131,6 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -135,8 +131,6 @@ class MsgDispatcher(MsgDispatcherBase):
else
:
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
return
True
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
"""
"""
data: it has three keys: trial_job_id, event, hyper_params
data: it has three keys: trial_job_id, event, hyper_params
...
@@ -152,7 +146,6 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -152,7 +146,6 @@ class MsgDispatcher(MsgDispatcherBase):
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
)
return
True
def
_handle_final_metric_data
(
self
,
data
):
def
_handle_final_metric_data
(
self
,
data
):
"""Call tuner to process final results
"""Call tuner to process final results
...
@@ -168,19 +161,19 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -168,19 +161,19 @@ class MsgDispatcher(MsgDispatcherBase):
"""Call assessor to process intermediate results
"""Call assessor to process intermediate results
"""
"""
if
data
[
'type'
]
!=
'PERIODICAL'
:
if
data
[
'type'
]
!=
'PERIODICAL'
:
return
True
return
if
self
.
assessor
is
None
:
if
self
.
assessor
is
None
:
return
True
return
trial_job_id
=
data
[
'trial_job_id'
]
trial_job_id
=
data
[
'trial_job_id'
]
if
trial_job_id
in
_ended_trials
:
if
trial_job_id
in
_ended_trials
:
return
True
return
history
=
_trial_history
[
trial_job_id
]
history
=
_trial_history
[
trial_job_id
]
history
[
data
[
'sequence'
]]
=
data
[
'value'
]
history
[
data
[
'sequence'
]]
=
data
[
'value'
]
ordered_history
=
_sort_history
(
history
)
ordered_history
=
_sort_history
(
history
)
if
len
(
ordered_history
)
<
data
[
'sequence'
]:
# no user-visible update since last time
if
len
(
ordered_history
)
<
data
[
'sequence'
]:
# no user-visible update since last time
return
True
return
try
:
try
:
result
=
self
.
assessor
.
assess_trial
(
trial_job_id
,
ordered_history
)
result
=
self
.
assessor
.
assess_trial
(
trial_job_id
,
ordered_history
)
...
...
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @
07655bf4
...
@@ -152,8 +152,7 @@ class MsgDispatcherBase(Recoverable):
...
@@ -152,8 +152,7 @@ class MsgDispatcherBase(Recoverable):
}
}
if
command
not
in
command_handlers
:
if
command
not
in
command_handlers
:
raise
AssertionError
(
'Unsupported command: {}'
.
format
(
command
))
raise
AssertionError
(
'Unsupported command: {}'
.
format
(
command
))
command_handlers
[
command
](
data
)
return
command_handlers
[
command
](
data
)
def
handle_ping
(
self
,
data
):
def
handle_ping
(
self
,
data
):
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment