Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c6b7cc89
Commit
c6b7cc89
authored
Mar 12, 2019
by
Shufan Huang
Committed by
xuehui
Mar 12, 2019
Browse files
Solve bug caused by scientific calculation errors (#828)
* add epsilon * add epsilon for ceil
parent
8f71e99f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
4 deletions
+5
-4
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+5
-4
No files found.
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
c6b7cc89
...
@@ -38,6 +38,7 @@ _logger = logging.getLogger(__name__)
...
@@ -38,6 +38,7 @@ _logger = logging.getLogger(__name__)
_next_parameter_id
=
0
_next_parameter_id
=
0
_KEY
=
'STEPS'
_KEY
=
'STEPS'
_epsilon
=
1e-6
@
unique
@
unique
class
OptimizeMode
(
Enum
):
class
OptimizeMode
(
Enum
):
...
@@ -141,8 +142,8 @@ class Bracket():
...
@@ -141,8 +142,8 @@ class Bracket():
self
.
bracket_id
=
s
self
.
bracket_id
=
s
self
.
s_max
=
s_max
self
.
s_max
=
s_max
self
.
eta
=
eta
self
.
eta
=
eta
self
.
n
=
math
.
ceil
((
s_max
+
1
)
*
(
eta
**
s
)
/
(
s
+
1
))
# pylint: disable=invalid-name
self
.
n
=
math
.
ceil
((
s_max
+
1
)
*
(
eta
**
s
)
/
(
s
+
1
)
-
_epsilon
)
# pylint: disable=invalid-name
self
.
r
=
math
.
ceil
(
R
/
eta
**
s
)
# pylint: disable=invalid-name
self
.
r
=
math
.
ceil
(
R
/
eta
**
s
-
_epsilon
)
# pylint: disable=invalid-name
self
.
i
=
0
self
.
i
=
0
self
.
hyper_configs
=
[]
# [ {id: params}, {}, ... ]
self
.
hyper_configs
=
[]
# [ {id: params}, {}, ... ]
self
.
configs_perf
=
[]
# [ {id: [seq, acc]}, {}, ... ]
self
.
configs_perf
=
[]
# [ {id: [seq, acc]}, {}, ... ]
...
@@ -157,7 +158,7 @@ class Bracket():
...
@@ -157,7 +158,7 @@ class Bracket():
def
get_n_r
(
self
):
def
get_n_r
(
self
):
"""return the values of n and r for the next round"""
"""return the values of n and r for the next round"""
return
math
.
floor
(
self
.
n
/
self
.
eta
**
self
.
i
),
self
.
r
*
self
.
eta
**
self
.
i
return
math
.
floor
(
self
.
n
/
self
.
eta
**
self
.
i
+
_epsilon
),
self
.
r
*
self
.
eta
**
self
.
i
def
increase_i
(
self
):
def
increase_i
(
self
):
"""i means the ith round. Increase i by 1"""
"""i means the ith round. Increase i by 1"""
...
@@ -305,7 +306,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -305,7 +306,7 @@ class Hyperband(MsgDispatcherBase):
self
.
brackets
=
dict
()
# dict of Bracket
self
.
brackets
=
dict
()
# dict of Bracket
self
.
generated_hyper_configs
=
[]
# all the configs waiting for run
self
.
generated_hyper_configs
=
[]
# all the configs waiting for run
self
.
completed_hyper_configs
=
[]
# all the completed configs
self
.
completed_hyper_configs
=
[]
# all the completed configs
self
.
s_max
=
math
.
floor
(
math
.
log
(
self
.
R
,
self
.
eta
))
self
.
s_max
=
math
.
floor
(
math
.
log
(
self
.
R
,
self
.
eta
)
+
_epsilon
)
self
.
curr_s
=
self
.
s_max
self
.
curr_s
=
self
.
s_max
self
.
searchspace_json
=
None
self
.
searchspace_json
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment