Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3efc59ee
Unverified
Commit
3efc59ee
authored
May 11, 2020
by
QuanluZhang
Committed by
GitHub
May 11, 2020
Browse files
improve PBT tuner (#2357)
parent
7e35d32e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
189 additions
and
41 deletions
+189
-41
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
+132
-41
src/sdk/pynni/tests/test_builtin_tuners.py
src/sdk/pynni/tests/test_builtin_tuners.py
+57
-0
No files found.
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
View file @
3efc59ee
...
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
...
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
random_state
=
np
.
random
.
RandomState
()
random_state
=
np
.
random
.
RandomState
()
hyper_parameters
[
'load_checkpoint_dir'
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
for
key
in
hyper_parameters
.
keys
():
for
key
in
hyper_parameters
.
keys
():
hyper_parameter
=
hyper_parameters
[
key
]
hyper_parameter
=
hyper_parameters
[
key
]
if
key
==
'load_checkpoint_dir'
:
if
key
==
'load_checkpoint_dir'
or
key
==
'save_checkpoint_dir'
:
hyper_parameters
[
key
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
continue
elif
key
==
'save_checkpoint_dir'
:
hyper_parameters
[
key
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
continue
continue
elif
search_space
[
key
][
"_type"
]
==
"choice"
:
elif
search_space
[
key
][
"_type"
]
==
"choice"
:
choices
=
search_space
[
key
][
"_value"
]
choices
=
search_space
[
key
][
"_value"
]
ub
,
uv
=
len
(
choices
)
-
1
,
choices
.
index
(
hyper_parameter
[
"_value"
]
)
+
1
ub
,
uv
=
len
(
choices
)
-
1
,
choices
.
index
(
hyper_parameter
)
+
1
lb
,
lv
=
0
,
choices
.
index
(
hyper_parameter
[
"_value"
]
)
-
1
lb
,
lv
=
0
,
choices
.
index
(
hyper_parameter
)
-
1
elif
search_space
[
key
][
"_type"
]
==
"randint"
:
elif
search_space
[
key
][
"_type"
]
==
"randint"
:
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
ub
-=
1
ub
-=
1
...
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
...
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
else
:
else
:
logger
.
warning
(
"Illegal type to perturb: %s"
,
search_space
[
key
][
"_type"
])
logger
.
warning
(
"Illegal type to perturb: %s"
,
search_space
[
key
][
"_type"
])
continue
continue
if
search_space
[
key
][
"_type"
]
==
"choice"
:
if
search_space
[
key
][
"_type"
]
==
"choice"
:
idx
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
idx
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
hyper_parameters
[
key
]
=
{
'_index'
:
idx
,
'_value'
:
choices
[
idx
]
}
hyper_parameters
[
key
]
=
choices
[
idx
]
else
:
else
:
hyper_parameters
[
key
]
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
hyper_parameters
[
key
]
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
...
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
...
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
for
i
in
range
(
self
.
population_size
):
for
i
in
range
(
self
.
population_size
):
hyper_parameters
=
json2parameter
(
hyper_parameters
=
json2parameter
(
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
hyper_parameters
=
split_index
(
hyper_parameters
)
checkpoint_dir
=
os
.
path
.
join
(
self
.
all_checkpoint_dir
,
str
(
i
))
checkpoint_dir
=
os
.
path
.
join
(
self
.
all_checkpoint_dir
,
str
(
i
))
hyper_parameters
[
'load_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'load_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
checkpoint_dir
,
str
(
self
.
epoch
))
...
@@ -294,7 +294,42 @@ class PBTTuner(Tuner):
...
@@ -294,7 +294,42 @@ class PBTTuner(Tuner):
trial_info
.
parameter_id
=
parameter_id
trial_info
.
parameter_id
=
parameter_id
self
.
running
[
parameter_id
]
=
trial_info
self
.
running
[
parameter_id
]
=
trial_info
logger
.
info
(
'Generate parameter : %s'
,
trial_info
.
hyper_parameters
)
logger
.
info
(
'Generate parameter : %s'
,
trial_info
.
hyper_parameters
)
return
split_index
(
trial_info
.
hyper_parameters
)
return
trial_info
.
hyper_parameters
def
_proceed_next_epoch
(
self
):
"""
"""
logger
.
info
(
'Proceeding to next epoch'
)
self
.
epoch
+=
1
self
.
population
=
[]
self
.
pos
=
-
1
self
.
running
=
{}
#exploit and explore
reverse
=
True
if
self
.
optimize_mode
==
OptimizeMode
.
Maximize
else
False
self
.
finished
=
sorted
(
self
.
finished
,
key
=
lambda
x
:
x
.
score
,
reverse
=
reverse
)
cutoff
=
int
(
np
.
ceil
(
self
.
fraction
*
len
(
self
.
finished
)))
tops
=
self
.
finished
[:
cutoff
]
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
for
bottom
in
bottoms
:
top
=
np
.
random
.
choice
(
tops
)
exploit_and_explore
(
bottom
,
top
,
self
.
factor
,
self
.
resample_probability
,
self
.
epoch
,
self
.
searchspace_json
)
for
trial
in
self
.
finished
:
if
trial
not
in
bottoms
:
trial
.
clean_id
()
trial
.
hyper_parameters
[
'load_checkpoint_dir'
]
=
trial
.
hyper_parameters
[
'save_checkpoint_dir'
]
trial
.
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
trial
.
checkpoint_dir
,
str
(
self
.
epoch
))
self
.
finished_trials
=
0
for
_
in
range
(
self
.
population_size
):
trial_info
=
self
.
finished
.
pop
()
self
.
population
.
append
(
trial_info
)
while
self
.
credit
>
0
and
self
.
pos
+
1
<
len
(
self
.
population
):
self
.
credit
-=
1
self
.
pos
+=
1
parameter_id
=
self
.
param_ids
.
pop
()
trial_info
=
self
.
population
[
self
.
pos
]
trial_info
.
parameter_id
=
parameter_id
self
.
running
[
parameter_id
]
=
trial_info
self
.
send_trial_callback
(
parameter_id
,
trial_info
.
hyper_parameters
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
"""
...
@@ -312,43 +347,99 @@ class PBTTuner(Tuner):
...
@@ -312,43 +347,99 @@ class PBTTuner(Tuner):
"""
"""
logger
.
info
(
'Get one trial result, id = %d, value = %s'
,
parameter_id
,
value
)
logger
.
info
(
'Get one trial result, id = %d, value = %s'
,
parameter_id
,
value
)
value
=
extract_scalar_reward
(
value
)
value
=
extract_scalar_reward
(
value
)
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
.
score
=
value
self
.
finished
.
append
(
trial_info
)
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
population_size
:
self
.
_proceed_next_epoch
()
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""
Deal with trial failure
Parameters
----------
parameter_id : int
Unique identifier for hyper-parameters used by this trial.
success : bool
True if the trial successfully completed; False if failed or terminated.
**kwargs
Unstable parameters which should be ignored by normal users.
"""
if
success
:
return
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
:
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
:
value
=
-
value
value
=
float
(
'inf'
)
else
:
value
=
float
(
'-inf'
)
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
=
self
.
running
.
pop
(
parameter_id
,
None
)
trial_info
.
score
=
value
trial_info
.
score
=
value
self
.
finished
.
append
(
trial_info
)
self
.
finished
.
append
(
trial_info
)
self
.
finished_trials
+=
1
self
.
finished_trials
+=
1
if
self
.
finished_trials
==
self
.
population_size
:
if
self
.
finished_trials
==
self
.
population_size
:
logger
.
info
(
'Proceeding to next epoch'
)
self
.
_proceed_next_epoch
()
self
.
epoch
+=
1
self
.
population
=
[]
self
.
pos
=
-
1
self
.
running
=
{}
#exploit and explore
self
.
finished
=
sorted
(
self
.
finished
,
key
=
lambda
x
:
x
.
score
,
reverse
=
True
)
cutoff
=
int
(
np
.
ceil
(
self
.
fraction
*
len
(
self
.
finished
)))
tops
=
self
.
finished
[:
cutoff
]
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
for
bottom
in
bottoms
:
top
=
np
.
random
.
choice
(
tops
)
exploit_and_explore
(
bottom
,
top
,
self
.
factor
,
self
.
resample_probability
,
self
.
epoch
,
self
.
searchspace_json
)
for
trial
in
self
.
finished
:
if
trial
not
in
bottoms
:
trial
.
clean_id
()
trial
.
hyper_parameters
[
'load_checkpoint_dir'
]
=
trial
.
hyper_parameters
[
'save_checkpoint_dir'
]
trial
.
hyper_parameters
[
'save_checkpoint_dir'
]
=
os
.
path
.
join
(
trial
.
checkpoint_dir
,
str
(
self
.
epoch
))
self
.
finished_trials
=
0
for
_
in
range
(
self
.
population_size
):
trial_info
=
self
.
finished
.
pop
()
self
.
population
.
append
(
trial_info
)
while
self
.
credit
>
0
and
self
.
pos
+
1
<
len
(
self
.
population
):
self
.
credit
-=
1
self
.
pos
+=
1
parameter_id
=
self
.
param_ids
.
pop
()
trial_info
=
self
.
population
[
self
.
pos
]
trial_info
.
parameter_id
=
parameter_id
self
.
running
[
parameter_id
]
=
trial_info
self
.
send_trial_callback
(
parameter_id
,
split_index
(
trial_info
.
hyper_parameters
))
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
pass
"""
Parameters
----------
data : json obj
imported data records
Returns
-------
int
the start epoch number after data imported, only used for unittest
"""
if
self
.
running
:
logger
.
warning
(
"Do not support importing data in the middle of experiment"
)
return
# the following is for experiment resume
_completed_num
=
0
epoch_data_dict
=
{}
for
trial_info
in
data
:
logger
.
info
(
"Process data record %s / %s"
,
_completed_num
,
len
(
data
))
_completed_num
+=
1
# simply validate data format
_params
=
trial_info
[
"parameter"
]
_value
=
trial_info
[
'value'
]
# assign fake value for failed trials
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
,
_value
)
_value
=
float
(
'inf'
)
if
self
.
optimize_mode
==
OptimizeMode
.
Minimize
else
float
(
'-inf'
)
_value
=
extract_scalar_reward
(
_value
)
if
'save_checkpoint_dir'
not
in
_params
:
logger
.
warning
(
"Invalid data record: save_checkpoint_dir is missing, abandon data import."
)
return
epoch_num
=
int
(
os
.
path
.
basename
(
_params
[
'save_checkpoint_dir'
]))
if
epoch_num
not
in
epoch_data_dict
:
epoch_data_dict
[
epoch_num
]
=
[]
epoch_data_dict
[
epoch_num
].
append
((
_params
,
_value
))
if
not
epoch_data_dict
:
logger
.
warning
(
"No valid epochs, abandon data import."
)
return
# figure out start epoch for resume
max_epoch_num
=
max
(
epoch_data_dict
,
key
=
int
)
if
len
(
epoch_data_dict
[
max_epoch_num
])
<
self
.
population_size
:
max_epoch_num
-=
1
# If there is no a single complete round, no data to import, start from scratch
if
max_epoch_num
<
0
:
logger
.
warning
(
"No completed epoch, abandon data import."
)
return
assert
len
(
epoch_data_dict
[
max_epoch_num
])
==
self
.
population_size
# check existence of trial save checkpoint dir
for
params
,
_
in
epoch_data_dict
[
max_epoch_num
]:
if
not
os
.
path
.
isdir
(
params
[
'save_checkpoint_dir'
]):
logger
.
warning
(
"save_checkpoint_dir %s does not exist, data will not be resumed"
,
params
[
'save_checkpoint_dir'
])
return
# resume data
self
.
epoch
=
max_epoch_num
self
.
finished_trials
=
self
.
population_size
for
params
,
value
in
epoch_data_dict
[
max_epoch_num
]:
checkpoint_dir
=
os
.
path
.
dirname
(
params
[
'save_checkpoint_dir'
])
self
.
finished
.
append
(
TrialInfo
(
checkpoint_dir
=
checkpoint_dir
,
hyper_parameters
=
params
,
score
=
value
))
self
.
_proceed_next_epoch
()
logger
.
info
(
"Successfully import data to PBT tuner, total data: %d, imported data: %d."
,
len
(
data
),
self
.
population_size
)
logger
.
info
(
"Start from epoch %d ..."
,
self
.
epoch
)
return
self
.
epoch
# return for test
src/sdk/pynni/tests/test_builtin_tuners.py
View file @
3efc59ee
...
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
...
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
logger
.
info
(
"Full supported search space: %s"
,
full_supported_search_space
)
logger
.
info
(
"Full supported search space: %s"
,
full_supported_search_space
)
self
.
search_space_test_one
(
tuner_factory
,
full_supported_search_space
)
self
.
search_space_test_one
(
tuner_factory
,
full_supported_search_space
)
def
import_data_test_for_pbt
(
self
):
"""
test1: import data with complete epoch
test2: import data with incomplete epoch
"""
search_space
=
{
"choice_str"
:
{
"_type"
:
"choice"
,
"_value"
:
[
"cat"
,
"dog"
,
"elephant"
,
"cow"
,
"sheep"
,
"panda"
]
}
}
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
)
population_size
=
4
# ===import data at the beginning===
tuner
=
PBTTuner
(
all_checkpoint_dir
=
all_checkpoint_dir
,
population_size
=
population_size
)
self
.
assertIsInstance
(
tuner
,
Tuner
)
tuner
.
update_search_space
(
search_space
)
save_dirs
=
[
os
.
path
.
join
(
all_checkpoint_dir
,
str
(
i
),
str
(
0
))
for
i
in
range
(
population_size
)]
# create save checkpoint directory
for
save_dir
in
save_dirs
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
# for simplicity, omit "load_checkpoint_dir"
data
=
[{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
0
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
1
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
2
]},
"value"
:
11
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
3
]},
"value"
:
7
}]
epoch
=
tuner
.
import_data
(
data
)
self
.
assertEqual
(
epoch
,
1
)
logger
.
info
(
"Imported data successfully at the beginning"
)
shutil
.
rmtree
(
all_checkpoint_dir
)
# ===import another data at the beginning, test the case when there is an incompleted epoch===
tuner
=
PBTTuner
(
all_checkpoint_dir
=
all_checkpoint_dir
,
population_size
=
population_size
)
self
.
assertIsInstance
(
tuner
,
Tuner
)
tuner
.
update_search_space
(
search_space
)
for
i
in
range
(
population_size
-
1
):
save_dirs
.
append
(
os
.
path
.
join
(
all_checkpoint_dir
,
str
(
i
),
str
(
1
)))
for
save_dir
in
save_dirs
:
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
data
=
[{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
0
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
1
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
2
]},
"value"
:
11
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
3
]},
"value"
:
7
},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
4
]},
"value"
:
1.1
},
{
"parameter"
:
{
"choice_str"
:
"dog"
,
"save_checkpoint_dir"
:
save_dirs
[
5
]},
"value"
:
{
"default"
:
1.2
,
"tmp"
:
2
}},
{
"parameter"
:
{
"choice_str"
:
"cat"
,
"save_checkpoint_dir"
:
save_dirs
[
6
]},
"value"
:
11
}]
epoch
=
tuner
.
import_data
(
data
)
self
.
assertEqual
(
epoch
,
1
)
logger
.
info
(
"Imported data successfully at the beginning with incomplete epoch"
)
shutil
.
rmtree
(
all_checkpoint_dir
)
def
import_data_test
(
self
,
tuner_factory
,
stype
=
"choice_str"
):
def
import_data_test
(
self
,
tuner_factory
,
stype
=
"choice_str"
):
"""
"""
import data at the beginning with number value and dict value
import data at the beginning with number value and dict value
...
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
...
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
),
all_checkpoint_dir
=
os
.
path
.
expanduser
(
"~/nni/checkpoint/test/"
),
population_size
=
100
population_size
=
100
))
))
self
.
import_data_test_for_pbt
()
def
tearDown
(
self
):
def
tearDown
(
self
):
file_list
=
glob
.
glob
(
"smac3*"
)
+
[
"param_config_space.pcs"
,
"scenario.txt"
,
"model_path"
]
file_list
=
glob
.
glob
(
"smac3*"
)
+
[
"param_config_space.pcs"
,
"scenario.txt"
,
"model_path"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment