Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d2c57770
Unverified
Commit
d2c57770
authored
Apr 08, 2020
by
RayMeng8
Committed by
GitHub
Apr 08, 2020
Browse files
Add supported data types for PBT tuner (#2271)
parent
c61700f3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
111 additions
and
17 deletions
+111
-17
examples/trials/mnist-pbt-tuner-pytorch/__init__.py
examples/trials/mnist-pbt-tuner-pytorch/__init__.py
+0
-0
examples/trials/mnist-pbt-tuner-pytorch/mnist.py
examples/trials/mnist-pbt-tuner-pytorch/mnist.py
+2
-2
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
+109
-15
No files found.
examples/trials/mnist-pbt-tuner-pytorch/__init__.py
deleted
100644 → 0
View file @
c61700f3
examples/trials/mnist-pbt-tuner-pytorch/mnist.py
View file @
d2c57770
...
@@ -155,8 +155,8 @@ def get_params():
...
@@ -155,8 +155,8 @@ def get_params():
help
=
'learning rate (default: 0.01)'
)
help
=
'learning rate (default: 0.01)'
)
parser
.
add_argument
(
'--momentum'
,
type
=
float
,
default
=
0.5
,
metavar
=
'M'
,
parser
.
add_argument
(
'--momentum'
,
type
=
float
,
default
=
0.5
,
metavar
=
'M'
,
help
=
'SGD momentum (default: 0.5)'
)
help
=
'SGD momentum (default: 0.5)'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
1
0
,
metavar
=
'N'
,
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
1
,
metavar
=
'N'
,
help
=
'number of epochs to train (default: 1
0
)'
)
help
=
'number of epochs to train (default: 1)'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1
,
metavar
=
'S'
,
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1
,
metavar
=
'S'
,
help
=
'random seed (default: 1)'
)
help
=
'random seed (default: 1)'
)
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
...
...
src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
View file @
d2c57770
...
@@ -4,9 +4,11 @@
...
@@ -4,9 +4,11 @@
import
copy
import
copy
import
logging
import
logging
import
os
import
os
import
random
import
numpy
as
np
import
numpy
as
np
import
nni
import
nni
import
nni.parameter_expressions
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
,
split_index
,
json2parameter
,
json2space
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
,
split_index
,
json2parameter
,
json2space
...
@@ -14,7 +16,42 @@ from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2par
...
@@ -14,7 +16,42 @@ from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2par
logger
=
logging
.
getLogger
(
'pbt_tuner_AutoML'
)
logger
=
logging
.
getLogger
(
'pbt_tuner_AutoML'
)
def
exploit_and_explore
(
bot_trial_info
,
top_trial_info
,
factors
,
epoch
,
search_space
):
def
perturbation
(
hyperparameter_type
,
value
,
resample_probablity
,
uv
,
ub
,
lv
,
lb
,
random_state
):
"""
Perturbation for hyperparameters
Parameters
----------
hyperparameter_type : str
type of hyperparameter
value : list
parameters for sampling hyperparameter
resample_probability : float
probability for resampling
uv : float/int
upper value after perturbation
ub : float/int
upper bound
lv : float/int
lower value after perturbation
lb : float/int
lower bound
random_state : RandomState
random state
"""
if
random
.
random
()
<
resample_probablity
:
if
hyperparameter_type
==
"choice"
:
return
value
.
index
(
nni
.
parameter_expressions
.
choice
(
value
,
random_state
))
else
:
return
getattr
(
nni
.
parameter_expressions
,
hyperparameter_type
)(
*
(
value
+
[
random_state
]))
else
:
if
random
.
random
()
>
0.5
:
return
min
(
uv
,
ub
)
else
:
return
max
(
lv
,
lb
)
def
exploit_and_explore
(
bot_trial_info
,
top_trial_info
,
factor
,
resample_probability
,
epoch
,
search_space
):
"""
"""
Replace checkpoint of bot_trial with top, and perturb hyperparameters
Replace checkpoint of bot_trial with top, and perturb hyperparameters
...
@@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
...
@@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bottom model whose parameters should be replaced
bottom model whose parameters should be replaced
top_trial_info : TrialInfo
top_trial_info : TrialInfo
better model
better model
factors : float
factor : float
factors for perturbation
factor for perturbation
resample_probability : float
probability for resampling
epoch : int
epoch : int
step of PBTTuner
step of PBTTuner
search_space : dict
search_space : dict
...
@@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
...
@@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bot_checkpoint_dir
=
bot_trial_info
.
checkpoint_dir
bot_checkpoint_dir
=
bot_trial_info
.
checkpoint_dir
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
top_hyper_parameters
=
top_trial_info
.
hyper_parameters
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
hyper_parameters
=
copy
.
deepcopy
(
top_hyper_parameters
)
# TODO think about different type of hyperparameters for 1.perturbation 2.within search space
random_state
=
np
.
random
.
RandomState
()
for
key
in
hyper_parameters
.
keys
():
for
key
in
hyper_parameters
.
keys
():
hyper_parameter
=
hyper_parameters
[
key
]
if
key
==
'load_checkpoint_dir'
:
if
key
==
'load_checkpoint_dir'
:
hyper_parameters
[
key
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
hyper_parameters
[
key
]
=
hyper_parameters
[
'save_checkpoint_dir'
]
continue
elif
key
==
'save_checkpoint_dir'
:
elif
key
==
'save_checkpoint_dir'
:
hyper_parameters
[
key
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
hyper_parameters
[
key
]
=
os
.
path
.
join
(
bot_checkpoint_dir
,
str
(
epoch
))
elif
isinstance
(
hyper_parameters
[
key
],
float
):
continue
perturb
=
np
.
random
.
choice
(
factors
)
elif
search_space
[
key
][
"_type"
]
==
"choice"
:
val
=
hyper_parameters
[
key
]
*
perturb
choices
=
search_space
[
key
][
"_value"
]
ub
,
uv
=
len
(
choices
)
-
1
,
choices
.
index
(
hyper_parameter
[
"_value"
])
+
1
lb
,
lv
=
0
,
choices
.
index
(
hyper_parameter
[
"_value"
])
-
1
elif
search_space
[
key
][
"_type"
]
==
"randint"
:
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
if
search_space
[
key
][
"_type"
]
in
(
"uniform"
,
"normal"
):
ub
-=
1
val
=
np
.
clip
(
val
,
lb
,
ub
).
item
()
uv
=
hyper_parameter
+
1
hyper_parameters
[
key
]
=
val
lv
=
hyper_parameter
-
1
elif
search_space
[
key
][
"_type"
]
==
"uniform"
:
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
perturb
=
(
ub
-
lb
)
*
factor
uv
=
hyper_parameter
+
perturb
lv
=
hyper_parameter
-
perturb
elif
search_space
[
key
][
"_type"
]
==
"quniform"
:
lb
,
ub
,
q
=
search_space
[
key
][
"_value"
][:
3
]
multi
=
round
(
hyper_parameter
/
q
)
uv
=
(
multi
+
1
)
*
q
lv
=
(
multi
-
1
)
*
q
elif
search_space
[
key
][
"_type"
]
==
"loguniform"
:
lb
,
ub
=
search_space
[
key
][
"_value"
][:
2
]
perturb
=
(
np
.
log
(
ub
)
-
np
.
log
(
lb
))
*
factor
uv
=
np
.
exp
(
min
(
np
.
log
(
hyper_parameter
)
+
perturb
,
np
.
log
(
ub
)))
lv
=
np
.
exp
(
max
(
np
.
log
(
hyper_parameter
)
-
perturb
,
np
.
log
(
lb
)))
elif
search_space
[
key
][
"_type"
]
==
"qloguniform"
:
lb
,
ub
,
q
=
search_space
[
key
][
"_value"
][:
3
]
multi
=
round
(
hyper_parameter
/
q
)
uv
=
(
multi
+
1
)
*
q
lv
=
(
multi
-
1
)
*
q
elif
search_space
[
key
][
"_type"
]
==
"normal"
:
sigma
=
search_space
[
key
][
"_value"
][
1
]
perturb
=
sigma
*
factor
uv
=
ub
=
hyper_parameter
+
perturb
lv
=
lb
=
hyper_parameter
-
perturb
elif
search_space
[
key
][
"_type"
]
==
"qnormal"
:
q
=
search_space
[
key
][
"_value"
][
2
]
uv
=
ub
=
hyper_parameter
+
q
lv
=
lb
=
hyper_parameter
-
q
elif
search_space
[
key
][
"_type"
]
==
"lognormal"
:
sigma
=
search_space
[
key
][
"_value"
][
1
]
perturb
=
sigma
*
factor
uv
=
ub
=
np
.
exp
(
np
.
log
(
hyper_parameter
)
+
perturb
)
lv
=
lb
=
np
.
exp
(
np
.
log
(
hyper_parameter
)
-
perturb
)
elif
search_space
[
key
][
"_type"
]
==
"qlognormal"
:
q
=
search_space
[
key
][
"_value"
][
2
]
uv
=
ub
=
hyper_parameter
+
q
lv
,
lb
=
hyper_parameter
-
q
,
1E-10
else
:
else
:
logger
.
warning
(
"Illegal type to perturb: %s"
,
search_space
[
key
][
"_type"
])
continue
continue
if
search_space
[
key
][
"_type"
]
==
"choice"
:
idx
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
hyper_parameters
[
key
]
=
{
'_index'
:
idx
,
'_value'
:
choices
[
idx
]}
else
:
hyper_parameters
[
key
]
=
perturbation
(
search_space
[
key
][
"_type"
],
search_space
[
key
][
"_value"
],
resample_probability
,
uv
,
ub
,
lv
,
lb
,
random_state
)
bot_trial_info
.
hyper_parameters
=
hyper_parameters
bot_trial_info
.
hyper_parameters
=
hyper_parameters
bot_trial_info
.
clean_id
()
bot_trial_info
.
clean_id
()
...
@@ -70,7 +160,8 @@ class TrialInfo:
...
@@ -70,7 +160,8 @@ class TrialInfo:
class
PBTTuner
(
Tuner
):
class
PBTTuner
(
Tuner
):
def
__init__
(
self
,
optimize_mode
=
"maximize"
,
all_checkpoint_dir
=
None
,
population_size
=
10
,
factors
=
(
1.2
,
0.8
),
fraction
=
0.2
):
def
__init__
(
self
,
optimize_mode
=
"maximize"
,
all_checkpoint_dir
=
None
,
population_size
=
10
,
factor
=
0.2
,
resample_probability
=
0.25
,
fraction
=
0.2
):
"""
"""
Initialization
Initialization
...
@@ -82,8 +173,10 @@ class PBTTuner(Tuner):
...
@@ -82,8 +173,10 @@ class PBTTuner(Tuner):
directory to store training model checkpoint
directory to store training model checkpoint
population_size : int
population_size : int
number of trials for each epoch
number of trials for each epoch
factors : tuple
factor : float
factors for perturbation
factor for perturbation
resample_probability : float
probability for resampling
fraction : float
fraction : float
fraction for selecting bottom and top trials
fraction for selecting bottom and top trials
"""
"""
...
@@ -93,7 +186,8 @@ class PBTTuner(Tuner):
...
@@ -93,7 +186,8 @@ class PBTTuner(Tuner):
logger
.
info
(
"Checkpoint dir is set to %s by default."
,
all_checkpoint_dir
)
logger
.
info
(
"Checkpoint dir is set to %s by default."
,
all_checkpoint_dir
)
self
.
all_checkpoint_dir
=
all_checkpoint_dir
self
.
all_checkpoint_dir
=
all_checkpoint_dir
self
.
population_size
=
population_size
self
.
population_size
=
population_size
self
.
factors
=
factors
self
.
factor
=
factor
self
.
resample_probability
=
resample_probability
self
.
fraction
=
fraction
self
.
fraction
=
fraction
# defined in trial code
# defined in trial code
#self.perturbation_interval = perturbation_interval
#self.perturbation_interval = perturbation_interval
...
@@ -237,7 +331,7 @@ class PBTTuner(Tuner):
...
@@ -237,7 +331,7 @@ class PBTTuner(Tuner):
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
bottoms
=
self
.
finished
[
self
.
finished_trials
-
cutoff
:]
for
bottom
in
bottoms
:
for
bottom
in
bottoms
:
top
=
np
.
random
.
choice
(
tops
)
top
=
np
.
random
.
choice
(
tops
)
exploit_and_explore
(
bottom
,
top
,
self
.
factor
s
,
self
.
epoch
,
self
.
searchspace_json
)
exploit_and_explore
(
bottom
,
top
,
self
.
factor
,
self
.
resample_probability
,
self
.
epoch
,
self
.
searchspace_json
)
for
trial
in
self
.
finished
:
for
trial
in
self
.
finished
:
if
trial
not
in
bottoms
:
if
trial
not
in
bottoms
:
trial
.
clean_id
()
trial
.
clean_id
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment