OpenDAS / nni · Commit 3efc59ee (unverified)

Authored May 11, 2020 by QuanluZhang; committed by GitHub on May 11, 2020.
improve PBT tuner (#2357)
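This commit reworks the PBT tuner around plainly formatted hyperparameters: split_index is now applied once when parameters are generated instead of at every hand-off, the end-of-epoch bookkeeping moves from receive_trial_result into a new private _proceed_next_epoch helper, a trial_end override gives failed trials a worst-case score so they no longer block the population from advancing, and the import_data stub becomes real resume logic that rebuilds the last fully completed epoch from exported records. A new unit test, import_data_test_for_pbt, covers resume with both complete and incomplete final epochs.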
Parent: 7e35d32e

Showing 2 changed files with 189 additions and 41 deletions:

  src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py     +132  -41
  src/sdk/pynni/tests/test_builtin_tuners.py    +57   -0
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py (view file @ 3efc59ee)
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
     top_hyper_parameters = top_trial_info.hyper_parameters
     hyper_parameters = copy.deepcopy(top_hyper_parameters)
     random_state = np.random.RandomState()
+    hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
+    hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
     for key in hyper_parameters.keys():
         hyper_parameter = hyper_parameters[key]
-        if key == 'load_checkpoint_dir':
-            hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
-            continue
-        elif key == 'save_checkpoint_dir':
-            hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
+        if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
             continue
         elif search_space[key]["_type"] == "choice":
             choices = search_space[key]["_value"]
-            ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
-            lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
+            ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
+            lb, lv = 0, choices.index(hyper_parameter) - 1
         elif search_space[key]["_type"] == "randint":
             lb, ub = search_space[key]["_value"][:2]
             ub -= 1
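A note on the choices.index(...) change above: the tuner now keeps hyperparameters in plain, de-indexed form, because split_index is applied once right after json2parameter (see the @@ -231 hunk below), so a choice value is stored directly rather than as an index-wrapped dict. A minimal sketch of the two formats, inferred from the {'_index': idx, '_value': choices[idx]} assignment this commit removes:

# Illustration only; names follow the NNI formatted-parameter convention.
wrapped = {"choice_str": {"_index": 1, "_value": "dog"}}  # json2parameter output
plain = {"choice_str": "dog"}                             # after split_index

choices = ["cat", "dog", "elephant"]
assert choices.index(wrapped["choice_str"]["_value"]) == 1  # old lookup
assert choices.index(plain["choice_str"]) == 1              # new lookup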
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
         else:
             logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
             continue
         if search_space[key]["_type"] == "choice":
             idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                                resample_probability, uv, ub, lv, lb, random_state)
-            hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
+            hyper_parameters[key] = choices[idx]
         else:
             hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                                                  resample_probability, uv, ub, lv, lb, random_state)
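A small worked example of the bound arithmetic feeding this call, for a choice parameter: (lb, ub) clamp the index range, while (lv, uv) are the neighbors of the current index which, judging by the call above, perturbation uses together with resample_probability to pick a nearby or resampled index.

choices = ["cat", "dog", "elephant", "cow"]
hyper_parameter = "dog"                                          # current plain value
ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1    # ub=3, uv=2
lb, lv = 0, choices.index(hyper_parameter) - 1                   # lb=0, lv=0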
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
         for i in range(self.population_size):
             hyper_parameters = json2parameter(
                 self.searchspace_json, is_rand, self.random_state)
+            hyper_parameters = split_index(hyper_parameters)
             checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
             hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
             hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
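Each population member i checkpoints under all_checkpoint_dir/<i>/<epoch>, and at initialization the load and save directories coincide. A short sketch of the paths this loop produces (root path hypothetical):

import os

all_checkpoint_dir = "/tmp/nni/checkpoint"   # hypothetical root
epoch = 0
for i in range(2):                           # population_size = 2
    checkpoint_dir = os.path.join(all_checkpoint_dir, str(i))
    print(os.path.join(checkpoint_dir, str(epoch)))
# /tmp/nni/checkpoint/0/0
# /tmp/nni/checkpoint/1/0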
@@ -294,38 +294,19 @@ class PBTTuner(Tuner):
         trial_info.parameter_id = parameter_id
         self.running[parameter_id] = trial_info
         logger.info('Generate parameter : %s', trial_info.hyper_parameters)
-        return split_index(trial_info.hyper_parameters)
+        return trial_info.hyper_parameters

-    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+    def _proceed_next_epoch(self):
         """
-        Receive trial's result. if the number of finished trials equals ``self.population_size``, start the next epoch to
-        train the model.
-
-        Parameters
-        ----------
-        parameter_id : int
-            Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`.
-        parameters : dict
-            Hyper-parameters generated by :meth:`generate_parameters`.
-        value : dict
-            Result from trial (the return value of :func:`nni.report_final_result`).
         """
-        logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
-        value = extract_scalar_reward(value)
-        if self.optimize_mode == OptimizeMode.Minimize:
-            value = -value
-        trial_info = self.running.pop(parameter_id, None)
-        trial_info.score = value
-        self.finished.append(trial_info)
-        self.finished_trials += 1
-        if self.finished_trials == self.population_size:
-            logger.info('Proceeding to next epoch')
-            self.epoch += 1
-            self.population = []
-            self.pos = -1
-            self.running = {}
-            #exploit and explore
-            self.finished = sorted(self.finished, key=lambda x: x.score, reverse=True)
-            cutoff = int(np.ceil(self.fraction * len(self.finished)))
-            tops = self.finished[:cutoff]
-            bottoms = self.finished[self.finished_trials - cutoff:]
+        logger.info('Proceeding to next epoch')
+        self.epoch += 1
+        self.population = []
+        self.pos = -1
+        self.running = {}
+        #exploit and explore
+        reverse = True if self.optimize_mode == OptimizeMode.Maximize else False
+        self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse)
+        cutoff = int(np.ceil(self.fraction * len(self.finished)))
+        tops = self.finished[:cutoff]
+        bottoms = self.finished[self.finished_trials - cutoff:]
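The sort now works on raw scores, with the direction taken from the optimize mode; previously scores were negated on receipt and always sorted descending. Keeping scores raw is what lets the new trial_end and import_data below insert float('inf') / float('-inf') sentinels directly. A two-line illustration of the equivalent ranking:

scores = [0.3, 0.9, 0.1]
print(sorted(scores, reverse=True))    # Maximize: best first -> [0.9, 0.3, 0.1]
print(sorted(scores, reverse=False))   # Minimize: best first -> [0.1, 0.3, 0.9]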
@@ -348,7 +329,117 @@ class PBTTuner(Tuner):
             trial_info = self.population[self.pos]
             trial_info.parameter_id = parameter_id
             self.running[parameter_id] = trial_info
-            self.send_trial_callback(parameter_id, split_index(trial_info.hyper_parameters))
+            self.send_trial_callback(parameter_id, trial_info.hyper_parameters)
+
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+        """
+        Receive trial's result. if the number of finished trials equals ``self.population_size``, start the next epoch to
+        train the model.
+
+        Parameters
+        ----------
+        parameter_id : int
+            Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`.
+        parameters : dict
+            Hyper-parameters generated by :meth:`generate_parameters`.
+        value : dict
+            Result from trial (the return value of :func:`nni.report_final_result`).
+        """
+        logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
+        value = extract_scalar_reward(value)
+        trial_info = self.running.pop(parameter_id, None)
+        trial_info.score = value
+        self.finished.append(trial_info)
+        self.finished_trials += 1
+        if self.finished_trials == self.population_size:
+            self._proceed_next_epoch()
+
+    def trial_end(self, parameter_id, success, **kwargs):
+        """
+        Deal with trial failure
+
+        Parameters
+        ----------
+        parameter_id : int
+            Unique identifier for hyper-parameters used by this trial.
+        success : bool
+            True if the trial successfully completed; False if failed or terminated.
+        **kwargs
+            Unstable parameters which should be ignored by normal users.
+        """
+        if success:
+            return
+        if self.optimize_mode == OptimizeMode.Minimize:
+            value = float('inf')
+        else:
+            value = float('-inf')
+        trial_info = self.running.pop(parameter_id, None)
+        trial_info.score = value
+        self.finished.append(trial_info)
+        self.finished_trials += 1
+        if self.finished_trials == self.population_size:
+            self._proceed_next_epoch()

     def import_data(self, data):
-        pass
+        """
+        Parameters
+        ----------
+        data : json obj
+            imported data records

+        Returns
+        -------
+        int
+            the start epoch number after data imported, only used for unittest
+        """
+        if self.running:
+            logger.warning("Do not support importing data in the middle of experiment")
+            return
+        # the following is for experiment resume
+        _completed_num = 0
+        epoch_data_dict = {}
+        for trial_info in data:
+            logger.info("Process data record %s / %s", _completed_num, len(data))
+            _completed_num += 1
+            # simply validate data format
+            _params = trial_info["parameter"]
+            _value = trial_info['value']
+            # assign fake value for failed trials
+            if not _value:
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
+                _value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
+            _value = extract_scalar_reward(_value)
+            if 'save_checkpoint_dir' not in _params:
+                logger.warning("Invalid data record: save_checkpoint_dir is missing, abandon data import.")
+                return
+            epoch_num = int(os.path.basename(_params['save_checkpoint_dir']))
+            if epoch_num not in epoch_data_dict:
+                epoch_data_dict[epoch_num] = []
+            epoch_data_dict[epoch_num].append((_params, _value))
+        if not epoch_data_dict:
+            logger.warning("No valid epochs, abandon data import.")
+            return
+        # figure out start epoch for resume
+        max_epoch_num = max(epoch_data_dict, key=int)
+        if len(epoch_data_dict[max_epoch_num]) < self.population_size:
+            max_epoch_num -= 1
+        # If there is no a single complete round, no data to import, start from scratch
+        if max_epoch_num < 0:
+            logger.warning("No completed epoch, abandon data import.")
+            return
+        assert len(epoch_data_dict[max_epoch_num]) == self.population_size
+        # check existence of trial save checkpoint dir
+        for params, _ in epoch_data_dict[max_epoch_num]:
+            if not os.path.isdir(params['save_checkpoint_dir']):
+                logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir'])
+                return
+        # resume data
+        self.epoch = max_epoch_num
+        self.finished_trials = self.population_size
+        for params, value in epoch_data_dict[max_epoch_num]:
+            checkpoint_dir = os.path.dirname(params['save_checkpoint_dir'])
+            self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value))
+        self._proceed_next_epoch()
+        logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), self.population_size)
+        logger.info("Start from epoch %d ...", self.epoch)
+        return self.epoch  # return for test
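import_data keys each record on the basename of its save_checkpoint_dir, which encodes the epoch number, then resumes from the last epoch that has a full population. A minimal sketch of an acceptable record list for population_size = 2 (paths hypothetical; the test below exercises the real format):

data = [
    {"parameter": {"choice_str": "cat",
                   "save_checkpoint_dir": "/tmp/nni/checkpoint/0/0"},  # epoch 0
     "value": 1.1},
    {"parameter": {"choice_str": "dog",
                   "save_checkpoint_dir": "/tmp/nni/checkpoint/1/0"},  # epoch 0
     "value": {"default": 1.2, "tmp": 2}},  # dict values reduced by extract_scalar_reward
]
# With all of epoch 0 present and its directories existing on disk, import_data
# resumes from epoch 0; _proceed_next_epoch then advances the counter, so the
# returned start epoch is 1.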
src/sdk/pynni/tests/test_builtin_tuners.py (view file @ 3efc59ee)
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
         logger.info("Full supported search space: %s", full_supported_search_space)
         self.search_space_test_one(tuner_factory, full_supported_search_space)

+    def import_data_test_for_pbt(self):
+        """
+        test1: import data with complete epoch
+        test2: import data with incomplete epoch
+        """
+        search_space = {
+            "choice_str": {
+                "_type": "choice",
+                "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"]
+            }
+        }
+        all_checkpoint_dir = os.path.expanduser("~/nni/checkpoint/test/")
+        population_size = 4
+        # ===import data at the beginning===
+        tuner = PBTTuner(all_checkpoint_dir=all_checkpoint_dir, population_size=population_size)
+        self.assertIsInstance(tuner, Tuner)
+        tuner.update_search_space(search_space)
+        save_dirs = [os.path.join(all_checkpoint_dir, str(i), str(0)) for i in range(population_size)]
+        # create save checkpoint directory
+        for save_dir in save_dirs:
+            os.makedirs(save_dir, exist_ok=True)
+        # for simplicity, omit "load_checkpoint_dir"
+        data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7}]
+        epoch = tuner.import_data(data)
+        self.assertEqual(epoch, 1)
+        logger.info("Imported data successfully at the beginning")
+        shutil.rmtree(all_checkpoint_dir)
+        # ===import another data at the beginning, test the case when there is an incompleted epoch===
+        tuner = PBTTuner(all_checkpoint_dir=all_checkpoint_dir, population_size=population_size)
+        self.assertIsInstance(tuner, Tuner)
+        tuner.update_search_space(search_space)
+        for i in range(population_size - 1):
+            save_dirs.append(os.path.join(all_checkpoint_dir, str(i), str(1)))
+        for save_dir in save_dirs:
+            os.makedirs(save_dir, exist_ok=True)
+        data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[4]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[5]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[6]}, "value": 11}]
+        epoch = tuner.import_data(data)
+        self.assertEqual(epoch, 1)
+        logger.info("Imported data successfully at the beginning with incomplete epoch")
+        shutil.rmtree(all_checkpoint_dir)
+
     def import_data_test(self, tuner_factory, stype="choice_str"):
         """
         import data at the beginning with number value and dict value
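The second half of this test feeds seven records to a population of four: epoch 1 holds only three of them, so import_data steps max_epoch_num back to the complete epoch 0, and both halves therefore assert a returned start epoch of 1.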
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
             all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
             population_size=100
         ))
+        self.import_data_test_for_pbt()

     def tearDown(self):
         file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]