Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3efc59ee
Unverified
Commit
3efc59ee
authored
May 11, 2020
by
QuanluZhang
Committed by
GitHub
May 11, 2020
Browse files
improve PBT tuner (#2357)
parent
7e35d32e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
189 additions
and
41 deletions
+189
-41
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
+132
-41
src/sdk/pynni/tests/test_builtin_tuners.py
src/sdk/pynni/tests/test_builtin_tuners.py
+57
-0
No files found.
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
View file @
3efc59ee
...
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
...
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
random_state
=
np
.
random
.
RandomState
()
random_state
=
np
.
random
.
RandomState
()
hyper_parameters
[
'load_checkpoint_dir'
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
for
key
in
hyper_parameters
.
keys
():
for
key
in
hyper_parameters
.
keys
():
hyper_parameter
=
hyper_parameters
[
key
]
hyper_parameter
=
hyper_parameters
[
key
]
if
key
==
'load_checkpoint_dir'
:
if
key
==
'load_checkpoint_dir'
or
key
==
'save_checkpoint_dir'
:
hyper_parameters
[
key
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
continue
elif
key
==
'save_checkpoint_dir'
:
hyper_parameters
[
key
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
continue
continue
elif
search_space
[
key
][
"_type"
]
==
"choice"
:
elif
search_space
[
key
][
"_type"
]
==
"choice"
:
choices
=
search_space
[
key
][
"_value"
]
choices
=
search_space
[
key
][
"_value"
]
ub
,
uv
=
len
(
choices
)
-
1
,
choices
.
index
(
hyper_parameter
[
"_value"
]
)
+
1
ub
,
uv
=
len
(
choices
)
-
1
,
choices
.
index
(
hyper_parameter
)
+
1
lb
,
lv
=
0
,
choices
.
index
(
hyper_parameter
[
"_value"
]
)
-
1
lb
,
lv
=
0
,
choices
.
index
(
hyper_parameter
)
-
1
elif
search_space
[
key
][
"_type"
]
==
"randint"
:
elif
search_space
[
key
][
"_type"
]
==
"randint"
:
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
ub
-=
1
ub
-=
1
...
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
...
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
else
:
else
:
logger
.
warning
(
"Illegal type to perturb: %s"
,
search_space
[
key
][
"_type"
])
logger
.
warning
(
"Illegal type to perturb: %s"
,
search_space
[
key
][
"_type"
])
continue
continue
if
search_space
[
key
][
"_type"
]
==
"choice"
:
if
search_space
[
key
][
"_type"
]
==
"choice"
:
idx
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
idx
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
hyper_parameters
[
key
]
=
{
'_index'
:
idx
,
'_value'
:
choices
[
idx
]
}
hyper_parameters
[
key
]
=
choices
[
idx
]
else
:
else
:
hyper_parameters
[
key
]
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
hyper_parameters
[
key
]
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
...
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
...
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
for
i
in
range
(
self
.
population_size
):
for
i
in
range
(
self
.
population_size
):
hyper_parameters
=
json2parameter
(
hyper_parameters
=
json2parameter
(
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
hyper_parameters
=
split_index
(
hyper_parameters
)
checkpoint_dir
=
os
.
path
.
join
(
self
.
all_checkpoint_dir
,
str
(
i
))
checkpoint_dir
=
os
.
path
.
join
(
self
.
all_checkpoint_dir
,
str
(
i
))
hyper_parameters
[
'load_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'load_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
...
@@ -294,38 +294,19 @@ class PBTTuner(Tuner):
...
@@ -294,38 +294,19 @@ class PBTTuner(Tuner):
trial_info
.
parameter_id
=
parameter_id
trial_info
.
parameter_id
=
parameter_id
self
.
running
[
parameter_id
]
=
trial_info
self
.
running
[
parameter_id
]
=
trial_info
logger
.
info
(
'Generate parameter : %s'
,
trial_info
.
hyper_parameters
)
logger
.
info
(
'Generate parameter : %s'
,
trial_info
.
hyper_parameters
)
return
split_index
(
trial_info
.
hyper_parameters
)
return
trial_info
.
hyper_parameters
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
def
_proceed_next_epoch
(
self
):
"""
"""
Receive trial's result. If the number of finished trials equals ``self.population_size``, start the next epoch to
train the model.
Parameters
----------
parameter_id : int
Unique identifier of the hyper-parameters used, same as in :meth:`generate_parameters`.
parameters : dict
Hyper-parameters generated by :meth:`generate_parameters`.
value : dict
Result from trial (the return value of :func:`nni.report_final_result`).
"""
"""
logger
.
info
(
'Get one trial result, id = %d, value = %s'
,
parameter_id
,
value
)
value
=
extract_scalar_reward
(
value
)
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
:
value
=
-
value
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
.
score
=
value
self
.
finished
.
append
(
trial_info
)
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
population_size
:
logger
.
info
(
'Proceeding to next epoch'
)
logger
.
info
(
'Proceeding to next epoch'
)
self
.
epoch
+=
1
self
.
epoch
+=
1
self
.
population
=
[]
self
.
population
=
[]
self
.
pos
=
-
1
self
.
pos
=
-
1
self
.
running
=
{}
self
.
running
=
{}
#exploit and explore
#exploit and explore
self
.
finished
=
sorted
(
self
.
finished
,
key
=
lambda
x
:
x
.
score
,
reverse
=
True
)
reverse
=
True
if
self
.
optimize_mode
==
OptimizeMode
.
Maximize
else
False
self
.
finished
=
sorted
(
self
.
finished
,
key
=
lambda
x
:
x
.
score
,
reverse
=
reverse
)
cutoff
=
int
(
np
.
ceil
(
self
.
fraction
*
len
(
self
.
finished
)))
cutoff
=
int
(
np
.
ceil
(
self
.
fraction
*
len
(
self
.
finished
)))
tops
=
self
.
finished
[:
cutoff
]
tops
=
self
.
finished
[:
cutoff
]
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
...
@@ -348,7 +329,117 @@ class PBTTuner(Tuner):
...
@@ -348,7 +329,117 @@ class PBTTuner(Tuner):
trial_info
=
self
.
population
[
self
.
pos
]
trial_info
=
self
.
population
[
self
.
pos
]
trial_info
.
parameter_id
=
parameter_id
trial_info
.
parameter_id
=
parameter_id
self
.
running
[
parameter_id
]
=
trial_info
self
.
running
[
parameter_id
]
=
trial_info
self
.
send_trial_callback
(
parameter_id
,
split_index
(
trial_info
.
hyper_parameters
))
self
.
send_trial_callback
(
parameter_id
,
trial_info
.
hyper_parameters
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
Receive trial's result. If the number of finished trials equals ``self.population_size``, start the next epoch to
train the model.
Parameters
----------
parameter_id : int
Unique identifier of the hyper-parameters used, same as in :meth:`generate_parameters`.
parameters : dict
Hyper-parameters generated by :meth:`generate_parameters`.
value : dict
Result from trial (the return value of :func:`nni.report_final_result`).
"""
logger
.
info
(
'Get one trial result, id = %d, value = %s'
,
parameter_id
,
value
)
value
=
extract_scalar_reward
(
value
)
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
.
score
=
value
self
.
finished
.
append
(
trial_info
)
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
population_size
:
self
.
_proceed_next_epoch
()
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""
Deal with trial failure
Parameters
----------
parameter_id : int
Unique identifier for hyper-parameters used by this trial.
success : bool
True if the trial successfully completed; False if failed or terminated.
**kwargs
Unstable parameters which should be ignored by normal users.
"""
if
success
:
return
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
:
value
=
float
(
'inf'
)
else
:
value
=
float
(
'-inf'
)
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
.
score
=
value
self
.
finished
.
append
(
trial_info
)
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
population_size
:
self
.
_proceed_next_epoch
()
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
pass
"""
Parameters
----------
data : json obj
imported data records
Returns
-------
int
the start epoch number after the data is imported; only used for unit tests
"""
if
self
.
running
:
logger
.
warning
(
"Do not support importing data in the middle of experiment"
)
return
# the following is for experiment resume
_completed_num
=
0
epoch_data_dict
=
{}
for
trial_info
in
data
:
logger
.
info
(
"Process data record %s / %s"
,
_completed_num
,
len
(
data
))
_completed_num
+=
1
# simply validate data format
_params
=
trial_info
[
"parameter"
]
_value
=
trial_info
[
'value'
]
# assign fake value for failed trials
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
,
_value
)
_value
=
float
(
'inf'
)
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
else
float
(
'-inf'
)
_value
=
extract_scalar_reward
(
_value
)
if
'save_checkpoint_dir'
not
in
_params
:
logger
.
warning
(
"Invalid data record: save_checkpoint_dir is missing, abandon data import."
)
return
epoch_num
=
int
(
os
.
path
.
basename
(
_params
[
'save_checkpoint_dir'
]))
if
epoch_num
not
in
epoch_data_dict
:
epoch_data_dict
[
epoch_num
]
=
[]
epoch_data_dict
[
epoch_num
].
append
((
_params
,
_value
))
if
not
epoch_data_dict
:
logger
.
warning
(
"No valid epochs, abandon data import."
)
return
# figure out start epoch for resume
max_epoch_num
=
max
(
epoch_data_dict
,
key
=
int
)
if
len
(
epoch_data_dict
[
max_epoch_num
])
<
self
.
population_size
:
max_epoch_num
-=
1
# If there is not a single complete round, there is no data to import; start from scratch
if
max_epoch_num
<
0
:
logger
.
warning
(
"No completed epoch, abandon data import."
)
return
assert
len
(
epoch_data_dict
[
max_epoch_num
])
==
self
.
population_size
# check existence of trial save checkpoint dir
for
params
,
_
in
epoch_data_dict
[
max_epoch_num
]:
if
not
os
.
path
.
isdir
(
params
[
'save_checkpoint_dir'
]):
logger
.
warning
(
"save_checkpoint_dir %s does not exist, data will not be resumed"
,
params
[
'save_checkpoint_dir'
])
return
# resume data
self
.
epoch
=
max_epoch_num
self
.
finished_trials
=
self
.
population_size
for
params
,
value
in
epoch_data_dict
[
max_epoch_num
]:
checkpoint_dir
=
os
.
path
.
dirname
(
params
[
'save_checkpoint_dir'
])
self
.
finished
.
append
(
TrialInfo
(
checkpoint_dir
=
checkpoint_dir
,
hyper_parameters
=
params
,
score
=
value
))
self
.
_proceed_next_epoch
()
logger
.
info
(
"Successfully import data to PBT tuner, total data: %d, imported data: %d."
,
len
(
data
),
self
.
population_size
)
logger
.
info
(
"Start from epoch %d ..."
,
self
.
epoch
)
return
self
.
epoch
# return for test
src/sdk/pynni/tests/test_builtin_tuners.py
View file @
3efc59ee
...
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
...
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
logger
.
info
(
"Full supported search space: %s"
,
full_supported_search_space
)
logger
.
info
(
"Full supported search space: %s"
,
full_supported_search_space
)
self
.
search_space_test_one
(
tuner_factory
,
full_supported_search_space
)
self
.
search_space_test_one
(
tuner_factory
,
full_supported_search_space
)
def
import_data_test_for_pbt
(
self
):
"""
test1: import data with complete epoch
test2: import data with incomplete epoch
"""
search_space
=
{
"choice_str"
:
{
"_type"
:
"choice"
,
"_value"
:
[
"cat"
,
"dog"
,
"elephant"
,
"cow"
,
"sheep"
,
"panda"
]
}
}
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
)
population_size
=
4
# ===import data at the beginning===
tuner
=
PBTTuner
(
all_checkpoint_dir
=
all_checkpoint_dir
,
population_size
=
population_size
)
self
.
assertIsInstance
(
tuner
,
Tuner
)
tuner
.
update_search_space
(
search_space
)
save_dirs
=
[
os
.
path
.
join
(
all_checkpoint_dir
,
str
(
i
),
str
(
0
))
for
i
in
range
(
population_size
)]
# create save checkpoint directory
for
save_dir
in
save_dirs
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
# for simplicity, omit "load_checkpoint_dir"
data
=
[{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
0
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
1
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
2
]},
"value"
:
11
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
3
]},
"value"
:
7
}]
epoch
=
tuner
.
import_data
(
data
)
self
.
assertEqual
(
epoch
,
1
)
logger
.
info
(
"Imported data successfully at the beginning"
)
shutil
.
rmtree
(
all_checkpoint_dir
)
# ===import another data set at the beginning, test the case when there is an incomplete epoch===
tuner
=
PBTTuner
(
all_checkpoint_dir
=
all_checkpoint_dir
,
population_size
=
population_size
)
self
.
assertIsInstance
(
tuner
,
Tuner
)
tuner
.
update_search_space
(
search_space
)
for
i
in
range
(
population_size
-
1
):
save_dirs
.
append
(
os
.
path
.
join
(
all_checkpoint_dir
,
str
(
i
),
str
(
1
)))
for
save_dir
in
save_dirs
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
data
=
[{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
0
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
1
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
2
]},
"value"
:
11
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
3
]},
"value"
:
7
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
4
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
5
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
6
]},
"value"
:
11
}]
epoch
=
tuner
.
import_data
(
data
)
self
.
assertEqual
(
epoch
,
1
)
logger
.
info
(
"Imported data successfully at the beginning with incomplete epoch"
)
shutil
.
rmtree
(
all_checkpoint_dir
)
def
import_data_test
(
self
,
tuner_factory
,
stype
=
"choice_str"
):
def
import_data_test
(
self
,
tuner_factory
,
stype
=
"choice_str"
):
"""
"""
import data at the beginning with number value and dict value
import data at the beginning with number value and dict value
...
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
...
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
),
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
),
population_size
=
100
population_size
=
100
))
))
self
.
import_data_test_for_pbt
()
def
tearDown
(
self
):
def
tearDown
(
self
):
file_list
=
glob
.
glob
(
"smac3*"
)
+
[
"param_config_space.pcs"
,
"scenario.txt"
,
"model_path"
]
file_list
=
glob
.
glob
(
"smac3*"
)
+
[
"param_config_space.pcs"
,
"scenario.txt"
,
"model_path"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment