OpenDAS / nni / Commits / 9fb25ccc

Unverified commit 9fb25ccc, authored Jul 17, 2019 by SparkSnail, committed by GitHub on Jul 17, 2019.

Merge pull request #189 from microsoft/master ("merge master").
Parents: 1500458a, 7c4bc33b.

Changes: 180. Showing 20 changed files with 890 additions and 237 deletions (+890 / -237).
Changed files:

examples/trials/mnist/config_windows.yml (+1 / -1)
examples/trials/network_morphism/README_zh_CN.md (+1 / -1)
examples/tuners/ga_customer_tuner/README_zh_CN.md (+1 / -1)
examples/tuners/weight_sharing/ga_customer_tuner/README_zh_CN.md (+1 / -1)
src/nni_manager/package.json (+4 / -0)
src/nni_manager/rest_server/restValidationSchemas.ts (+2 / -1)
src/nni_manager/yarn.lock (+70 / -87)
src/sdk/pynni/nni/__init__.py (+1 / -0)
src/sdk/pynni/nni/__main__.py (+0 / -2)
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py (+131 / -70)
src/sdk/pynni/nni/constants.py (+3 / -1)
src/sdk/pynni/nni/curvefitting_assessor/test.py (+2 / -0)
src/sdk/pynni/nni/gp_tuner/__init__.py (+0 / -0)
src/sdk/pynni/nni/gp_tuner/gp_tuner.py (+170 / -0)
src/sdk/pynni/nni/gp_tuner/target_space.py (+219 / -0)
src/sdk/pynni/nni/gp_tuner/util.py (+172 / -0)
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py (+61 / -39)
src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py (+5 / -4)
src/sdk/pynni/nni/metis_tuner/metis_tuner.py (+40 / -24)
src/sdk/pynni/nni/msg_dispatcher.py (+6 / -5)
examples/trials/mnist/config_windows.yml

@@ -9,7 +9,7 @@ searchSpacePath: search_space.json
 #choice: true, false
 useAnnotation: false
 tuner:
-  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
+  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
   #SMAC (SMAC should be installed through nnictl)
   builtinTunerName: TPE
   classArgs:
examples/trials/network_morphism/README_zh_CN.md

@@ -32,7 +32,7 @@ trainingServicePlatform: local
 useAnnotation: false
 tuner:
   #可选项: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
   #SMAC (SMAC 需要通过 nnictl 安装)
   builtinTunerName: NetworkMorphism
   classArgs:
     #可选项: maximize, minimize
examples/tuners/ga_customer_tuner/README_zh_CN.md

 # 如何使用 ga_customer_tuner?
-此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",
-输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。
+此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。

 # 配置
examples/tuners/weight_sharing/ga_customer_tuner/README_zh_CN.md

 # 如何使用 ga_customer_tuner?
-此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",
-输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。
+此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。

 # 配置
src/nni_manager/package.json

@@ -54,6 +54,10 @@
     "tslint-microsoft-contrib": "^6.0.0",
     "typescript": "^3.2.2"
   },
+  "resolutions": {
+    "mem": "^4.0.0",
+    "handlebars": "^4.1.0"
+  },
   "engines": {
     "node": ">=10.0.0"
   },
src/nni_manager/rest_server/restValidationSchemas.ts

@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
         command: joi.string().min(1),
         virtualCluster: joi.string(),
         shmMB: joi.number(),
+        nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode'),
         worker: joi.object({
             replicas: joi.number().min(1).required(),
             image: joi.string().min(1),

@@ -161,7 +162,7 @@ export namespace ValidationSchemas {
             checkpointDir: joi.string().allow('')
         }),
         tuner: joi.object({
-            builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner'),
+            builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'),
            codeDir: joi.string(),
            classFileName: joi.string(),
            className: joi.string(),
src/nni_manager/yarn.lock

@@ -145,6 +145,10 @@
     "@types/minimatch" "*"
     "@types/node" "*"

+"@types/js-base64@^2.3.1":
+  version "2.3.1"
+  resolved "https://registry.yarnpkg.com/@types/js-base64/-/js-base64-2.3.1.tgz#c39f14f129408a3d96a1105a650d8b2b6eeb4168"
+
 "@types/mime@*":
   version "2.0.0"
   resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"

@@ -161,9 +165,9 @@
   version "10.5.2"
   resolved "https://registry.yarnpkg.com/@types/node/-/node-10.5.2.tgz#f19f05314d5421fe37e74153254201a7bf00a707"

-"@types/node@^10.5.5":
-  version "10.5.5"
-  resolved "https://registry.yarnpkg.com/@types/node/-/node-10.5.5.tgz#8e84d24e896cd77b0d4f73df274027e3149ec2ba"
+"@types/node@10.12.18":
+  version "10.12.18"
+  resolved "https://registry.yarnpkg.com/@types/node/-/node-10.12.18.tgz#1d3ca764718915584fcd9f6344621b7672665c67"

 "@types/range-parser@*":
   version "1.2.2"

@@ -342,10 +346,6 @@ ansi-regex@^3.0.0:
   version "3.0.0"
   resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-3.0.0.tgz#ed0317c322064f79466c02966bddb605ab37d998"

-ansi-styles@^2.2.1:
-  version "2.2.1"
-  resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-2.2.1.tgz#b432dd3358b634cf75e1e4664368240533c1ddbe"
-
 ansi-styles@^3.2.1:
   version "3.2.1"
   resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-3.2.1.tgz#41fbb20243e50b12be0f04b8dedbf07520ce841d"

@@ -413,12 +413,6 @@ async-limiter@~1.0.0:
   version "1.0.0"
   resolved "https://registry.yarnpkg.com/async-limiter/-/async-limiter-1.0.0.tgz#78faed8c3d074ab81f22b4e985d79e8738f720f8"

-async@^2.5.0:
-  version "2.6.1"
-  resolved "https://registry.yarnpkg.com/async/-/async-2.6.1.tgz#b245a23ca71930044ec53fa46aa00a3e87c6a610"
-  dependencies:
-    lodash "^4.17.10"
-
 asynckit@^0.4.0:
   version "0.4.0"
   resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79"

@@ -451,14 +445,6 @@ azure-storage@^2.10.2:
     xml2js "0.2.8"
     xmlbuilder "^9.0.7"

-babel-code-frame@^6.22.0:
-  version "6.26.0"
-  resolved "https://registry.yarnpkg.com/babel-code-frame/-/babel-code-frame-6.26.0.tgz#63fd43f7dc1e3bb7ce35947db8fe369a3f58c74b"
-  dependencies:
-    chalk "^1.1.3"
-    esutils "^2.0.2"
-    js-tokens "^3.0.2"
-
 balanced-match@^1.0.0:
   version "1.0.0"
   resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.0.tgz#89b4d199ab2bee49de164ea02b89ce462d71b767"

@@ -575,16 +561,6 @@ chai@^4.1.2:
     pathval "^1.0.0"
     type-detect "^4.0.0"

-chalk@^1.1.3:
-  version "1.1.3"
-  resolved "https://registry.yarnpkg.com/chalk/-/chalk-1.1.3.tgz#a8115c55e4a702fe4d150abd3872822a7e09fc98"
-  dependencies:
-    ansi-styles "^2.2.1"
-    escape-string-regexp "^1.0.2"
-    has-ansi "^2.0.0"
-    strip-ansi "^3.0.0"
-    supports-color "^2.0.0"
-
 chalk@^2.0.0, chalk@^2.3.0:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"

@@ -854,7 +830,7 @@ escape-html@~1.0.3:
   version "1.0.3"
   resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"

-escape-string-regexp@1.0.5, escape-string-regexp@^1.0.2, escape-string-regexp@^1.0.5:
+escape-string-regexp@1.0.5, escape-string-regexp@^1.0.5:
   version "1.0.5"
   resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4"

@@ -1135,11 +1111,11 @@ growl@1.10.5:
   version "1.10.5"
   resolved "https://registry.yarnpkg.com/growl/-/growl-1.10.5.tgz#f2735dc2283674fa67478b10181059355c369e5e"

-handlebars@^4.0.11:
-  version "4.0.12"
-  resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.0.12.tgz#2c15c8a96d46da5e266700518ba8cb8d919d5bc5"
+handlebars@^4.0.11, handlebars@^4.1.0:
+  version "4.1.2"
+  resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.1.2.tgz#b6b37c1ced0306b221e094fc7aca3ec23b131b67"
   dependencies:
-    async "^2.5.0"
+    neo-async "^2.6.0"
     optimist "^0.6.1"
     source-map "^0.6.1"
   optionalDependencies:

@@ -1163,12 +1139,6 @@ har-validator@~5.1.0:
     ajv "^5.3.0"
     har-schema "^2.0.0"

-has-ansi@^2.0.0:
-  version "2.0.0"
-  resolved "https://registry.yarnpkg.com/has-ansi/-/has-ansi-2.0.0.tgz#34f5049ce1ecdf2b0649af3ef24e45ed35416d91"
-  dependencies:
-    ansi-regex "^2.0.0"
-
 has-flag@^3.0.0:
   version "3.0.0"
   resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd"

@@ -1426,21 +1396,24 @@ js-base64@^2.4.9:
   version "2.5.0"
   resolved "https://registry.yarnpkg.com/js-base64/-/js-base64-2.5.0.tgz#42255ba183ab67ce59a0dee640afdc00ab5ae93e"

-js-tokens@^3.0.2:
-  version "3.0.2"
-  resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-3.0.2.tgz#9866df395102130e38f7f996bceb65443209c25b"
-
 js-tokens@^4.0.0:
   version "4.0.0"
   resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"

-js-yaml@^3.10.0, js-yaml@^3.7.0:
+js-yaml@^3.10.0:
   version "3.12.0"
   resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.12.0.tgz#eaed656ec8344f10f527c6bfa1b6e2244de167d1"
   dependencies:
     argparse "^1.0.7"
     esprima "^4.0.0"

+js-yaml@^3.13.1:
+  version "3.13.1"
+  resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.13.1.tgz#aff151b30bfdfa8e49e05da22e7415e9dfa37847"
+  dependencies:
+    argparse "^1.0.7"
+    esprima "^4.0.0"
+
 jsbn@~0.1.0:
   version "0.1.1"
   resolved "https://registry.yarnpkg.com/jsbn/-/jsbn-0.1.1.tgz#a5e654c2e5a2deb5f201d96cefbca80c0ef2f513"

@@ -1619,6 +1592,12 @@ make-error@^1.1.1:
   version "1.3.4"
   resolved "https://registry.yarnpkg.com/make-error/-/make-error-1.3.4.tgz#19978ed575f9e9545d2ff8c13e33b5d18a67d535"

+map-age-cleaner@^0.1.1:
+  version "0.1.3"
+  resolved "https://registry.yarnpkg.com/map-age-cleaner/-/map-age-cleaner-0.1.3.tgz#7d583a7306434c055fe474b0f45078e6e1b4b92a"
+  dependencies:
+    p-defer "^1.0.0"
+
 md5-hex@^2.0.0:
   version "2.0.0"
   resolved "https://registry.yarnpkg.com/md5-hex/-/md5-hex-2.0.0.tgz#d0588e9f1c74954492ecd24ac0ac6ce997d92e33"

@@ -1640,11 +1619,13 @@ media-typer@0.3.0:
   version "0.3.0"
   resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748"

-mem@^1.1.0:
-  version "1.1.0"
-  resolved "https://registry.yarnpkg.com/mem/-/mem-1.1.0.tgz#5edd52b485ca1d900fe64895505399a0dfa45f76"
+mem@^1.1.0, mem@^4.0.0:
+  version "4.3.0"
+  resolved "https://registry.yarnpkg.com/mem/-/mem-4.3.0.tgz#461af497bc4ae09608cdb2e60eefb69bff744178"
   dependencies:
-    mimic-fn "^1.0.0"
+    map-age-cleaner "^0.1.1"
+    mimic-fn "^2.0.0"
+    p-is-promise "^2.0.0"

 merge-descriptors@1.0.1:
   version "1.0.1"

@@ -1684,9 +1665,9 @@ mime@1.4.1:
   version "1.4.1"
   resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6"

-mimic-fn@^1.0.0:
-  version "1.2.0"
-  resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022"
+mimic-fn@^2.0.0:
+  version "2.1.0"
+  resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b"

 mimic-response@^1.0.0:
   version "1.0.1"

@@ -1773,6 +1754,10 @@ negotiator@0.6.1:
   version "0.6.1"
   resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.1.tgz#2b327184e8992101177b28563fb5e7102acd0ca9"

+neo-async@^2.6.0:
+  version "2.6.1"
+  resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.1.tgz#ac27ada66167fa8849a6addd837f6b189ad2081c"
+
 node-forge@^0.7.6:
   version "0.7.6"
   resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.7.6.tgz#fdf3b418aee1f94f0ef642cd63486c77ca9724ac"

@@ -1797,12 +1782,6 @@ node-jose@^1.1.0:
     node-forge "^0.7.6"
     uuid "^3.3.2"

-node-nvidia-smi@^1.0.0:
-  version "1.0.0"
-  resolved "https://registry.yarnpkg.com/node-nvidia-smi/-/node-nvidia-smi-1.0.0.tgz#6aa57574540b2bed91c9a80218516ffa686e5ac7"
-  dependencies:
-    xml2js "^0.4.17"
-
 node-pre-gyp@^0.10.3:
   version "0.10.3"
   resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc"

@@ -2005,6 +1984,10 @@ p-cancelable@^0.4.0:
   version "0.4.1"
   resolved "http://registry.npmjs.org/p-cancelable/-/p-cancelable-0.4.1.tgz#35f363d67d52081c8d9585e37bcceb7e0bbcb2a0"

+p-defer@^1.0.0:
+  version "1.0.0"
+  resolved "https://registry.yarnpkg.com/p-defer/-/p-defer-1.0.0.tgz#9f6eb182f6c9aa8cd743004a7d4f96b196b0fb0c"
+
 p-finally@^1.0.0:
   version "1.0.0"
   resolved "https://registry.yarnpkg.com/p-finally/-/p-finally-1.0.0.tgz#3fbcfb15b899a44123b34b6dcc18b724336a2cae"

@@ -2013,6 +1996,10 @@ p-is-promise@^1.1.0:
   version "1.1.0"
   resolved "http://registry.npmjs.org/p-is-promise/-/p-is-promise-1.1.0.tgz#9c9456989e9f6588017b0434d56097675c3da05e"

+p-is-promise@^2.0.0:
+  version "2.1.0"
+  resolved "https://registry.yarnpkg.com/p-is-promise/-/p-is-promise-2.1.0.tgz#918cebaea248a62cf7ffab8e3bca8c5f882fc42e"
+
 p-limit@^1.1.0:
   version "1.3.0"
   resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-1.3.0.tgz#b86bd5f0c25690911c7590fcbfc2010d54b3ccb8"

@@ -2384,7 +2371,7 @@ sax@0.5.x:
   version "0.5.8"
   resolved "http://registry.npmjs.org/sax/-/sax-0.5.8.tgz#d472db228eb331c2506b0e8c15524adb939d12c1"

-sax@>=0.6.0, sax@^1.2.4:
+sax@^1.2.4:
   version "1.2.4"
   resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"

@@ -2619,10 +2606,6 @@ supports-color@5.4.0, supports-color@^5.3.0:
   dependencies:
     has-flag "^3.0.0"

-supports-color@^2.0.0:
-  version "2.0.0"
-  resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-2.0.0.tgz#535d045ce6b6363fa40117084629995e9df324c7"
-
 supports-color@^5.4.0:
   version "5.5.0"
   resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.5.0.tgz#e2e69a44ac8772f78a1ec0b35b689df6530efc8f"

@@ -2716,30 +2699,37 @@ tslib@^1.8.0, tslib@^1.8.1:
   version "1.9.3"
   resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"

-tslint-microsoft-contrib@^5.1.0:
-  version "5.1.0"
-  resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-5.1.0.tgz#777c32d51aba16f4565e47aac749a1631176cd9f"
+tslint-microsoft-contrib@^6.0.0:
+  version "6.2.0"
+  resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
   dependencies:
-    tsutils "^2.12.1"
+    tsutils "^2.27.2 <2.29.0"

-tslint@^5.11.0:
-  version "5.11.0"
-  resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.11.0.tgz#98f30c02eae3cde7006201e4c33cb08b48581eed"
+tslint@^5.12.0:
+  version "5.18.0"
+  resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
   dependencies:
-    babel-code-frame "^6.22.0"
+    "@babel/code-frame" "^7.0.0"
     builtin-modules "^1.1.1"
     chalk "^2.3.0"
     commander "^2.12.1"
     diff "^3.2.0"
     glob "^7.1.1"
-    js-yaml "^3.7.0"
+    js-yaml "^3.13.1"
     minimatch "^3.0.4"
+    mkdirp "^0.5.1"
     resolve "^1.3.2"
     semver "^5.3.0"
     tslib "^1.8.0"
-    tsutils "^2.27.2"
+    tsutils "^2.29.0"

+"tsutils@^2.27.2 <2.29.0":
+  version "2.28.0"
+  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
+  dependencies:
+    tslib "^1.8.1"
+
-tsutils@^2.12.1, tsutils@^2.27.2:
+tsutils@^2.29.0:
   version "2.29.0"
   resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
   dependencies:

@@ -2777,9 +2767,9 @@ typescript-string-operations@^1.3.1:
   version "1.3.1"
   resolved "https://registry.yarnpkg.com/typescript-string-operations/-/typescript-string-operations-1.3.1.tgz#461b886cc9ccd4dd16810b1f248b2e6f6580956b"

-typescript@^3.0.1:
-  version "3.0.1"
-  resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.0.1.tgz#43738f29585d3a87575520a4b93ab6026ef11fdb"
+typescript@^3.2.2:
+  version "3.5.2"
+  resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.2.tgz#a09e1dc69bc9551cadf17dba10ee42cf55e5d56c"

 uglify-js@^3.1.4:
   version "3.4.9"

@@ -2900,14 +2890,7 @@ xml2js@0.2.8:
   dependencies:
     sax "0.5.x"

-xml2js@^0.4.17:
-  version "0.4.19"
-  resolved "https://registry.yarnpkg.com/xml2js/-/xml2js-0.4.19.tgz#686c20f213209e94abf0d1bcf1efaa291c7827a7"
-  dependencies:
-    sax ">=0.6.0"
-    xmlbuilder "~9.0.1"
-
-xmlbuilder@^9.0.7, xmlbuilder@~9.0.1:
+xmlbuilder@^9.0.7:
   version "9.0.7"
   resolved "http://registry.npmjs.org/xmlbuilder/-/xmlbuilder-9.0.7.tgz#132ee63d2ec5565c557e20f4c22df9aca686b10d"
src/sdk/pynni/nni/__init__.py

@@ -23,6 +23,7 @@
 from .trial import *
 from .smartparam import *
+from .nas_utils import reload_tensorflow_variables

 class NoMoreTrialError(Exception):
     def __init__(self, ErrorInfo):
src/sdk/pynni/nni/__main__.py

@@ -130,8 +130,6 @@ def main():
     if args.advisor_class_name:
         # advisor is enabled and starts to run
-        if args.multi_phase:
-            raise AssertionError('multi_phase has not been supported in advisor')
         if args.advisor_class_name in AdvisorModuleName:
             dispatcher = create_builtin_class_instance(
                 args.advisor_class_name,
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py

@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH
 from nni.protocol import CommandType, send
 from nni.msg_dispatcher_base import MsgDispatcherBase
-from nni.utils import OptimizeMode, extract_scalar_reward, randint_to_quniform
+from nni.utils import OptimizeMode, MetricType, extract_scalar_reward, randint_to_quniform
+from nni.common import multi_phase_enabled
 from .config_generator import CG_BOHB

@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
     return params_id

-class Bracket():
+class Bracket(object):
     """
     A bracket in BOHB, all the information of a bracket is managed by
     an instance of this class.

@@ -106,7 +107,7 @@ class Bracket():
         self.s_max = s_max
         self.eta = eta
         self.max_budget = max_budget
-        self.optimize_mode = optimize_mode
+        self.optimize_mode = OptimizeMode(optimize_mode)
         self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon)
         self.r = max_budget / eta**s

@@ -259,33 +260,34 @@ class BOHB(MsgDispatcherBase):
     optimize_mode: str
         optimize mode, 'maximize' or 'minimize'
     min_budget: float
         The smallest budget to consider. Needs to be positive!
     max_budget: float
         The largest budget to consider. Needs to be larger than min_budget!
         The budgets will be geometrically distributed
-        :math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
+        :math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
     eta: int
         In each iteration, a complete run of sequential halving is executed. In it,
         after evaluating each configuration on the same subset size, only a fraction of
         1/eta of them 'advances' to the next round.
         Must be greater or equal to 2.
     min_points_in_model: int
         number of observations to start building a KDE. Default 'None' means
         dim+1, the bare minimum.
     top_n_percent: int
         percentage ( between 1 and 99, default 15) of the observations that are considered good.
     num_samples: int
         number of samples to optimize EI (default 64)
     random_fraction: float
         fraction of purely random configurations that are sampled from the
         prior without the model.
     bandwidth_factor: float
         to encourage diversity, the points proposed to optimize EI, are sampled
         from a 'widened' KDE where the bandwidth is multiplied by this factor (default: 3)
     min_bandwidth: float
         to keep diversity, even when all (good) samples have the same value for one of the parameters,
         a minimum bandwidth (Default: 1e-3) is used instead of zero.
     """

     def __init__(self,
                  optimize_mode='maximize',
                  min_budget=1,

@@ -328,11 +330,12 @@ class BOHB(MsgDispatcherBase):
         # config generator
         self.cg = None

-    def load_checkpoint(self):
-        pass
-
-    def save_checkpoint(self):
-        pass
+        # record the latest parameter_id of the trial job trial_job_id.
+        # if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
+        # new trial job is added to this dict and finished trial job is removed from it.
+        self.job_id_para_id_map = dict()
+        # record the unsatisfied parameter request from trial jobs
+        self.unsatisfied_jobs = []

     def handle_initialize(self, data):
         """Initialize Tuner, including creating Bayesian optimization-based parametric models

@@ -398,7 +401,7 @@ class BOHB(MsgDispatcherBase):
         for _ in range(self.credit):
             self._request_one_trial_job()

-    def _request_one_trial_job(self):
+    def _get_one_trial_job(self):
         """get one trial job, i.e., one hyperparameter configuration.
         If this function is called, Command will be sent by BOHB:

@@ -422,7 +425,7 @@ class BOHB(MsgDispatcherBase):
                 'parameters': ''
             }
             send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
-            return
+            return None
         assert self.generated_hyper_configs
         params = self.generated_hyper_configs.pop()
         ret = {

@@ -431,8 +434,29 @@ class BOHB(MsgDispatcherBase):
             'parameters': params[1]
         }
         self.parameters[params[0]] = params[1]
-        send(CommandType.NewTrialJob, json_tricks.dumps(ret))
-        self.credit -= 1
+        return ret
+
+    def _request_one_trial_job(self):
+        """get one trial job, i.e., one hyperparameter configuration.
+        If this function is called, Command will be sent by BOHB:
+        a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
+        {
+            'parameter_id': id of new hyperparameter
+            'parameter_source': 'algorithm'
+            'parameters': value of new hyperparameter
+        }
+        b. If BOHB don't have parameter waiting, will return "NoMoreTrialJobs" with
+        {
+            'parameter_id': '-1_0_0',
+            'parameter_source': 'algorithm',
+            'parameters': ''
+        }
+        """
+        ret = self._get_one_trial_job()
+        if ret is not None:
+            send(CommandType.NewTrialJob, json_tricks.dumps(ret))
+            self.credit -= 1

     def handle_update_search_space(self, data):
         """change json format to ConfigSpace format dict<dict> -> configspace

@@ -501,23 +525,38 @@ class BOHB(MsgDispatcherBase):
         hyper_params: the hyperparameters (a string) generated and returned by tuner
         """
         logger.debug('Tuner handle trial end, result is %s', data)
         hyper_params = json_tricks.loads(data['hyper_params'])
-        s, i, _ = hyper_params['parameter_id'].split('_')
+        self._handle_trial_end(hyper_params['parameter_id'])
+        if data['trial_job_id'] in self.job_id_para_id_map:
+            del self.job_id_para_id_map[data['trial_job_id']]
+
+    def _send_new_trial(self):
+        while self.unsatisfied_jobs:
+            ret = self._get_one_trial_job()
+            if ret is None:
+                break
+            one_unsatisfied = self.unsatisfied_jobs.pop(0)
+            ret['trial_job_id'] = one_unsatisfied['trial_job_id']
+            ret['parameter_index'] = one_unsatisfied['parameter_index']
+            # update parameter_id in self.job_id_para_id_map
+            self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
+            send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
+        for _ in range(self.credit):
+            self._request_one_trial_job()
+
+    def _handle_trial_end(self, parameter_id):
+        s, i, _ = parameter_id.split('_')
         hyper_configs = self.brackets[int(s)].inform_trial_end(int(i))
         if hyper_configs is not None:
             logger.debug(
                 'bracket %s next round %s, hyper_configs: %s', s, i, hyper_configs)
             self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
-            for _ in range(self.credit):
-                self._request_one_trial_job()
         # Finish this bracket and generate a new bracket
         elif self.brackets[int(s)].no_more_trial:
             self.curr_s -= 1
             self.generate_new_bracket()
-            for _ in range(self.credit):
-                self._request_one_trial_job()
+        self._send_new_trial()

     def handle_report_metric_data(self, data):
         """reveice the metric data and update Bayesian optimization with final result

@@ -534,36 +573,58 @@ class BOHB(MsgDispatcherBase):
         """
         logger.debug('handle report metric data = %s', data)
-        assert 'value' in data
-        value = extract_scalar_reward(data['value'])
-        if self.optimize_mode is OptimizeMode.Maximize:
-            reward = -value
-        else:
-            reward = value
-        assert 'parameter_id' in data
-        s, i, _ = data['parameter_id'].split('_')
-        logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
-        s = int(s)
-        assert 'type' in data
-        if data['type'] == 'FINAL':
-            # and PERIODICAL metric are independent, thus, not comparable.
-            assert 'sequence' in data
-            self.brackets[s].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value)
-            self.completed_hyper_configs.append(data)
-            _parameters = self.parameters[data['parameter_id']]
-            _parameters.pop(_KEY)
-            # update BO with loss, max_s budget, hyperparameters
-            self.cg.new_result(loss=reward, budget=data['sequence'], parameters=_parameters, update_model=True)
-        elif data['type'] == 'PERIODICAL':
-            self.brackets[s].set_config_perf(int(i), data['parameter_id'], data['sequence'], value)
-        else:
-            raise ValueError('Data type not supported: {}'.format(data['type']))
+        if data['type'] == MetricType.REQUEST_PARAMETER:
+            assert multi_phase_enabled()
+            assert data['trial_job_id'] is not None
+            assert data['parameter_index'] is not None
+            assert data['trial_job_id'] in self.job_id_para_id_map
+            self._handle_trial_end(self.job_id_para_id_map[data['trial_job_id']])
+            ret = self._get_one_trial_job()
+            if ret is None:
+                self.unsatisfied_jobs.append({'trial_job_id': data['trial_job_id'], 'parameter_index': data['parameter_index']})
+            else:
+                ret['trial_job_id'] = data['trial_job_id']
+                ret['parameter_index'] = data['parameter_index']
+                # update parameter_id in self.job_id_para_id_map
+                self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
+                send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
+        else:
+            assert 'value' in data
+            value = extract_scalar_reward(data['value'])
+            if self.optimize_mode is OptimizeMode.Maximize:
+                reward = -value
+            else:
+                reward = value
+            assert 'parameter_id' in data
+            s, i, _ = data['parameter_id'].split('_')
+            logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
+            s = int(s)
+
+            # add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
+            # because when the first parameter_id is created, trial_job_id is not known yet.
+            if data['trial_job_id'] in self.job_id_para_id_map:
+                assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
+            else:
+                self.job_id_para_id_map[data['trial_job_id']] = data['parameter_id']
+
+            assert 'type' in data
+            if data['type'] == MetricType.FINAL:
+                # and PERIODICAL metric are independent, thus, not comparable.
+                assert 'sequence' in data
+                self.brackets[s].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value)
+                self.completed_hyper_configs.append(data)
+                _parameters = self.parameters[data['parameter_id']]
+                _parameters.pop(_KEY)
+                # update BO with loss, max_s budget, hyperparameters
+                self.cg.new_result(loss=reward, budget=data['sequence'], parameters=_parameters, update_model=True)
+            elif data['type'] == MetricType.PERIODICAL:
+                self.brackets[s].set_config_perf(int(i), data['parameter_id'], data['sequence'], value)
+            else:
+                raise ValueError('Data type not supported: {}'.format(data['type']))

     def handle_add_customized_trial(self, data):
         pass
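Taken together, the BOHB changes above add multi-phase support: the advisor remembers which parameter_id each trial job is currently running (job_id_para_id_map) and queues parameter requests it cannot serve yet (unsatisfied_jobs). The stand-alone sketch below mirrors only that bookkeeping; the trial-job IDs, the stub configuration queue, and every name other than the two attributes taken from the diff are invented for illustration and are not NNI APIs.

# Sketch only: mirrors the bookkeeping added to BOHB.handle_report_metric_data.
# A plain list stands in for generated_hyper_configs and the return value stands
# in for the real CommandType.SendTrialJobParameter message.
generated_configs = [('0_0_1', {'lr': 0.01}), ('0_0_0', {'lr': 0.1})]  # hypothetical
job_id_para_id_map = {}   # trial_job_id -> parameter_id currently running
unsatisfied_jobs = []     # parameter requests that could not be served yet

def get_one_trial_job():
    # counterpart of BOHB._get_one_trial_job(): None when nothing is ready
    if not generated_configs:
        return None
    parameter_id, parameters = generated_configs.pop()
    return {'parameter_id': parameter_id, 'parameters': parameters}

def on_request_parameter(trial_job_id, parameter_index):
    # counterpart of the new MetricType.REQUEST_PARAMETER branch
    ret = get_one_trial_job()
    if ret is None:
        unsatisfied_jobs.append({'trial_job_id': trial_job_id,
                                 'parameter_index': parameter_index})
        return None
    ret['trial_job_id'] = trial_job_id
    ret['parameter_index'] = parameter_index
    job_id_para_id_map[trial_job_id] = ret['parameter_id']
    return ret

print(on_request_parameter('trial_A', 1))  # serves '0_0_0'
print(on_request_parameter('trial_B', 1))  # serves '0_0_1'
print(on_request_parameter('trial_C', 1))  # None: request queued in unsatisfied_jobs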
src/sdk/pynni/nni/constants.py

@@ -29,7 +29,8 @@ ModuleName = {
     'GridSearch': 'nni.gridsearch_tuner.gridsearch_tuner',
     'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner',
     'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
-    'MetisTuner': 'nni.metis_tuner.metis_tuner'
+    'MetisTuner': 'nni.metis_tuner.metis_tuner',
+    'GPTuner': 'nni.gp_tuner.gp_tuner'
 }

 ClassName = {
@@ -42,6 +43,7 @@ ClassName = {
     'GridSearch': 'GridSearchTuner',
     'NetworkMorphism': 'NetworkMorphismTuner',
     'MetisTuner': 'MetisTuner',
+    'GPTuner': 'GPTuner',

     'Medianstop': 'MedianstopAssessor',
     'Curvefitting': 'CurvefittingAssessor'
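With these two table entries in place, the string 'GPTuner' given as builtinTunerName can be resolved to the module nni.gp_tuner.gp_tuner and the class GPTuner. The sketch below shows the idea of that lookup only; the helper name is made up and the real factory in NNI is create_builtin_class_instance.

# Illustrative sketch of resolving a builtin tuner name via the tables above.
import importlib

ModuleName = {'GPTuner': 'nni.gp_tuner.gp_tuner'}   # excerpt from constants.py
ClassName = {'GPTuner': 'GPTuner'}

def create_builtin_tuner(name, **class_args):
    # hypothetical helper: look up the module path and class name, then instantiate
    module = importlib.import_module(ModuleName[name])   # e.g. nni.gp_tuner.gp_tuner
    cls = getattr(module, ClassName[name])               # e.g. GPTuner
    return cls(**class_args)

# tuner = create_builtin_tuner('GPTuner', optimize_mode='maximize')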
src/sdk/pynni/nni/curvefitting_assessor/test.py

@@ -15,9 +15,11 @@
 # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

+import numpy as np
 import unittest
 from .curvefitting_assessor import CurvefittingAssessor
+from .model_factory import CurveModel
 from nni.assessor import AssessResult

 class TestCurveFittingAssessor(unittest.TestCase):
src/sdk/pynni/nni/gp_tuner/__init__.py (new file, mode 0 → 100644, empty)
src/sdk/pynni/nni/gp_tuner/gp_tuner.py (new file, mode 0 → 100644)

# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''
gp_tuner.py
'''

import warnings
import logging
import numpy as np

from sklearn.gaussian_process.kernels import Matern
from sklearn.gaussian_process import GaussianProcessRegressor

from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward

from .target_space import TargetSpace
from .util import UtilityFunction, acq_max

logger = logging.getLogger("GP_Tuner_AutoML")


class GPTuner(Tuner):
    '''
    GPTuner
    '''

    def __init__(self, optimize_mode="maximize", utility='ei', kappa=5, xi=0, nu=2.5, alpha=1e-6,
                 cold_start_num=10, selection_num_warm_up=100000, selection_num_starting_points=250):
        self.optimize_mode = OptimizeMode(optimize_mode)

        # utility function related
        self.utility = utility
        self.kappa = kappa
        self.xi = xi

        # target space
        self._space = None

        self._random_state = np.random.RandomState()

        # nu, alpha are GPR related params
        self._gp = GaussianProcessRegressor(
            kernel=Matern(nu=nu),
            alpha=alpha,
            normalize_y=True,
            n_restarts_optimizer=25,
            random_state=self._random_state
        )
        # num of random evaluations before GPR
        self._cold_start_num = cold_start_num

        # params for acq_max
        self._selection_num_warm_up = selection_num_warm_up
        self._selection_num_starting_points = selection_num_starting_points

        # num of imported data
        self.supplement_data_num = 0

    def update_search_space(self, search_space):
        """Update the self.bounds and self.types by the search_space.json

        Parameters
        ----------
        search_space : dict
        """
        self._space = TargetSpace(search_space, self._random_state)

    def generate_parameters(self, parameter_id):
        """Generate next parameter for trial
        If the number of trial result is lower than cold start number,
        gp will first randomly generate some parameters.
        Otherwise, choose the parameters by the Gussian Process Model

        Parameters
        ----------
        parameter_id : int

        Returns
        -------
        result : dict
        """
        if self._space.len() < self._cold_start_num:
            results = self._space.random_sample()
        else:
            # Sklearn's GP throws a large number of warnings at times, but
            # we don't really need to see them here.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self._gp.fit(self._space.params, self._space.target)

            util = UtilityFunction(
                kind=self.utility, kappa=self.kappa, xi=self.xi)

            results = acq_max(
                f_acq=util.utility,
                gp=self._gp,
                y_max=self._space.target.max(),
                bounds=self._space.bounds,
                space=self._space,
                num_warmup=self._selection_num_warm_up,
                num_starting_points=self._selection_num_starting_points
            )

        results = self._space.array_to_params(results)
        logger.info("Generate paramageters:\n%s", results)
        return results

    def receive_trial_result(self, parameter_id, parameters, value):
        """Tuner receive result from trial.

        Parameters
        ----------
        parameter_id : int
        parameters : dict
        value : dict/float
            if value is dict, it should have "default" key.
        """
        value = extract_scalar_reward(value)
        if self.optimize_mode == OptimizeMode.Minimize:
            value = -value

        logger.info("Received trial result.")
        logger.info("value :%s", value)
        logger.info("parameter : %s", parameters)
        self._space.register(parameters, value)

    def import_data(self, data):
        """Import additional data for tuning

        Parameters
        ----------
        data:
            a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
        """
        _completed_num = 0
        for trial_info in data:
            logger.info("Importing data, current processing progress %s / %s" % (_completed_num, len(data)))
            _completed_num += 1
            assert "parameter" in trial_info
            _params = trial_info["parameter"]
            assert "value" in trial_info
            _value = trial_info['value']
            if not _value:
                logger.info("Useless trial data, value is %s, skip this trial data." % _value)
                continue
            self.supplement_data_num += 1
            _parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
            self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value)
        logger.info("Successfully import data to GP tuner.")
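The new tuner follows the usual NNI Tuner protocol: update_search_space() receives the parsed search space, generate_parameters() returns the next configuration (random sampling until cold_start_num results have been registered, then the argmax of the acquisition function over the fitted GP), and receive_trial_result() registers each observation. The stand-alone loop below is a toy illustration only; the two-parameter search space and the synthetic objective are invented, and in a real experiment NNI drives these calls itself.

# Toy driver for GPTuner outside of an NNI experiment (assumes nni and sklearn installed).
from nni.gp_tuner.gp_tuner import GPTuner

search_space = {
    'lr':         {'_type': 'loguniform', '_value': [1e-4, 1e-1]},
    'batch_size': {'_type': 'choice',     '_value': [16, 32, 64, 128]},
}

def objective(params):
    # synthetic stand-in for a real trial's reported metric
    return -abs(params['lr'] - 0.01) + 0.001 * params['batch_size']

tuner = GPTuner(optimize_mode='maximize', cold_start_num=5)
tuner.update_search_space(search_space)

for trial_id in range(20):
    params = tuner.generate_parameters(trial_id)        # random, then GP-guided
    tuner.receive_trial_result(trial_id, params, objective(params))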
src/sdk/pynni/nni/gp_tuner/target_space.py (new file, mode 0 → 100644)

# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''
target_space.py
'''

import numpy as np
import nni.parameter_expressions as parameter_expressions


def _hashable(params):
    """ ensure that an point is hashable by a python dict """
    return tuple(map(float, params))


class TargetSpace():
    """
    Holds the param-space coordinates (X) and target values (Y)
    """

    def __init__(self, pbounds, random_state=None):
        """
        Parameters
        ----------
        pbounds : dict
            Dictionary with parameters names as keys and a tuple with minimum
            and maximum values.
        random_state : int, RandomState, or None
            optionally specify a seed for a random number generator
        """
        self.random_state = random_state

        # Get the name of the parameters
        self._keys = sorted(pbounds)
        # Create an array with parameters bounds
        self._bounds = np.array(
            [item[1] for item in sorted(pbounds.items(), key=lambda x: x[0])]
        )

        # preallocated memory for X and Y points
        self._params = np.empty(shape=(0, self.dim))
        self._target = np.empty(shape=(0))

        # keep track of unique points we have seen so far
        self._cache = {}

    def __contains__(self, params):
        '''
        check if a parameter is already registered
        '''
        return _hashable(params) in self._cache

    def len(self):
        '''
        length of registered params and targets
        '''
        assert len(self._params) == len(self._target)
        return len(self._target)

    @property
    def params(self):
        '''
        params: numpy array
        '''
        return self._params

    @property
    def target(self):
        '''
        target: numpy array
        '''
        return self._target

    @property
    def dim(self):
        '''
        dim: int
            length of keys
        '''
        return len(self._keys)

    @property
    def keys(self):
        '''
        keys: numpy array
        '''
        return self._keys

    @property
    def bounds(self):
        '''bounds'''
        return self._bounds

    def params_to_array(self, params):
        ''' dict to array '''
        try:
            assert set(params) == set(self.keys)
        except AssertionError:
            raise ValueError(
                "Parameters' keys ({}) do ".format(sorted(params)) +
                "not match the expected set of keys ({}).".format(self.keys)
            )
        return np.asarray([params[key] for key in self.keys])

    def array_to_params(self, x):
        '''
        array to dict

        maintain int type if the paramters is defined as int in search_space.json
        '''
        try:
            assert len(x) == len(self.keys)
        except AssertionError:
            raise ValueError(
                "Size of array ({}) is different than the ".format(len(x)) +
                "expected number of parameters ({}).".format(self.dim())
            )

        params = {}
        for i, _bound in enumerate(self._bounds):
            if _bound['_type'] == 'choice' and all(isinstance(val, int) for val in _bound['_value']):
                params.update({self.keys[i]: int(x[i])})
            elif _bound['_type'] in ['randint']:
                params.update({self.keys[i]: int(x[i])})
            else:
                params.update({self.keys[i]: x[i]})

        return params

    def register(self, params, target):
        """
        Append a point and its target value to the known data.

        Parameters
        ----------
        x : dict
        y : float
            target function value
        """
        x = self.params_to_array(params)
        if x in self:
            #raise KeyError('Data point {} is not unique'.format(x))
            print('Data point {} is not unique'.format(x))

        # Insert data into unique dictionary
        self._cache[_hashable(x.ravel())] = target

        self._params = np.concatenate([self._params, x.reshape(1, -1)])
        self._target = np.concatenate([self._target, [target]])

    def random_sample(self):
        """
        Creates a random point within the bounds of the space.
        """
        params = np.empty(self.dim)
        for col, _bound in enumerate(self._bounds):
            if _bound['_type'] == 'choice':
                params[col] = parameter_expressions.choice(_bound['_value'], self.random_state)
            elif _bound['_type'] == 'randint':
                params[col] = self.random_state.randint(
                    _bound['_value'][0], _bound['_value'][1], size=1)
            elif _bound['_type'] == 'uniform':
                params[col] = parameter_expressions.uniform(
                    _bound['_value'][0], _bound['_value'][1], self.random_state)
            elif _bound['_type'] == 'quniform':
                params[col] = parameter_expressions.quniform(
                    _bound['_value'][0], _bound['_value'][1], _bound['_value'][2], self.random_state)
            elif _bound['_type'] == 'loguniform':
                params[col] = parameter_expressions.loguniform(
                    _bound['_value'][0], _bound['_value'][1], self.random_state)
            elif _bound['_type'] == 'qloguniform':
                params[col] = parameter_expressions.qloguniform(
                    _bound['_value'][0], _bound['_value'][1], _bound['_value'][2], self.random_state)

        return params

    def max(self):
        """Get maximum target value found and corresponding parametes."""
        try:
            res = {
                'target': self.target.max(),
                'params': dict(
                    zip(self.keys, self.params[self.target.argmax()])
                )
            }
        except ValueError:
            res = {}
        return res

    def res(self):
        """Get all target values found and corresponding parametes."""
        params = [dict(zip(self.keys, p)) for p in self.params]

        return [
            {"target": target, "params": param}
            for target, param in zip(self.target, params)
        ]
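TargetSpace stores observations as a matrix of parameter vectors (columns ordered by the sorted parameter names) plus a vector of targets, and converts between that array form and NNI's dict-style parameters via params_to_array / array_to_params. A small illustration with an invented two-parameter space, assuming the nni package is importable:

# Illustration of the TargetSpace helpers with a made-up search space.
import numpy as np
from nni.gp_tuner.target_space import TargetSpace

space = TargetSpace(
    {'x': {'_type': 'uniform', '_value': [0, 10]},
     'y': {'_type': 'quniform', '_value': [1, 5, 1]}},
    random_state=np.random.RandomState(0))

sample = space.random_sample()          # numpy row in sorted-key order ['x', 'y']
params = space.array_to_params(sample)  # back to {'x': ..., 'y': ...}
space.register(params, target=1.23)     # appends to space.params / space.target
print(space.len(), space.max())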
src/sdk/pynni/nni/gp_tuner/util.py
0 → 100644
View file @
9fb25ccc
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
gp_tuner.py
'''
import
warnings
import
numpy
as
np
from
scipy.stats
import
norm
from
scipy.optimize
import
minimize
def
_match_val_type
(
vals
,
bounds
):
'''
Update values in the array, to match their corresponding type
'''
vals_new
=
[]
for
i
,
bound
in
enumerate
(
bounds
):
_type
=
bound
[
'_type'
]
if
_type
==
"choice"
:
# Find the closest integer in the array, vals_bounds
vals_new
.
append
(
min
(
bound
[
'_value'
],
key
=
lambda
x
:
abs
(
x
-
vals
[
i
])))
elif
_type
in
[
'quniform'
,
'randint'
]:
vals_new
.
append
(
np
.
around
(
vals
[
i
]))
else
:
vals_new
.
append
(
vals
[
i
])
return
vals_new
def acq_max(f_acq, gp, y_max, bounds, space, num_warmup, num_starting_points):
    """
    A function to find the maximum of the acquisition function

    It uses a combination of random sampling (cheap) and the 'L-BFGS-B'
    optimization method. First it samples `num_warmup` points at random,
    and then runs L-BFGS-B from `num_starting_points` random starting points.

    Parameters
    ----------
    :param f_acq:
        The acquisition function object that returns its point-wise value.
    :param gp:
        A gaussian process fitted to the relevant data.
    :param y_max:
        The current maximum known value of the target function.
    :param bounds:
        The variables bounds to limit the search of the acq max.
    :param num_warmup:
        number of times to randomly sample the acquisition function
    :param num_starting_points:
        number of times to run scipy.minimize

    Returns
    -------
    :return: x_max, The arg max of the acquisition function.
    """

    # Warm up with random points
    x_tries = [space.random_sample() for _ in range(int(num_warmup))]
    ys = f_acq(x_tries, gp=gp, y_max=y_max)
    x_max = x_tries[ys.argmax()]
    max_acq = ys.max()

    # Explore the parameter space more thoroughly
    x_seeds = [space.random_sample() for _ in range(int(num_starting_points))]

    bounds_minmax = np.array(
        [[bound['_value'][0], bound['_value'][-1]] for bound in bounds])

    for x_try in x_seeds:
        # Find the minimum of minus the acquisition function
        res = minimize(lambda x: -f_acq(x.reshape(1, -1), gp=gp, y_max=y_max),
                       x_try.reshape(1, -1),
                       bounds=bounds_minmax,
                       method="L-BFGS-B")

        # See if success
        if not res.success:
            continue

        # Store it if better than previous minimum (maximum).
        if max_acq is None or -res.fun[0] >= max_acq:
            x_max = _match_val_type(res.x, bounds)
            max_acq = -res.fun[0]

    # Clip output to make sure it lies within the bounds. Due to floating
    # point technicalities this is not always the case.
    return np.clip(x_max, bounds_minmax[:, 0], bounds_minmax[:, 1])
class UtilityFunction():
    """
    An object to compute the acquisition functions.
    """

    def __init__(self, kind, kappa, xi):
        """
        If UCB is to be used, a constant kappa is needed.
        """
        self.kappa = kappa
        self.xi = xi

        if kind not in ['ucb', 'ei', 'poi']:
            err = "The utility function " \
                  "{} has not been implemented, " \
                  "please choose one of ucb, ei, or poi.".format(kind)
            raise NotImplementedError(err)
        self.kind = kind

    def utility(self, x, gp, y_max):
        '''return utility function'''
        if self.kind == 'ucb':
            return self._ucb(x, gp, self.kappa)
        if self.kind == 'ei':
            return self._ei(x, gp, y_max, self.xi)
        if self.kind == 'poi':
            return self._poi(x, gp, y_max, self.xi)
        return None

    @staticmethod
    def _ucb(x, gp, kappa):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean, std = gp.predict(x, return_std=True)

        return mean + kappa * std

    @staticmethod
    def _ei(x, gp, y_max, xi):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean, std = gp.predict(x, return_std=True)

        z = (mean - y_max - xi) / std
        return (mean - y_max - xi) * norm.cdf(z) + std * norm.pdf(z)

    @staticmethod
    def _poi(x, gp, y_max, xi):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean, std = gp.predict(x, return_std=True)

        z = (mean - y_max - xi) / std
        return norm.cdf(z)
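For reference, the three acquisition functions implemented by UtilityFunction correspond to the following closed forms, where \mu(x) and \sigma(x) are the GP posterior mean and standard deviation returned by gp.predict, y_max is the best observed value, and \Phi, \varphi are the standard normal CDF and PDF (this is just the code above restated in math):

\begin{aligned}
z(x) &= \frac{\mu(x) - y_{\max} - \xi}{\sigma(x)} \\
\mathrm{UCB}(x) &= \mu(x) + \kappa\,\sigma(x) \\
\mathrm{EI}(x)  &= \bigl(\mu(x) - y_{\max} - \xi\bigr)\,\Phi\bigl(z(x)\bigr) + \sigma(x)\,\varphi\bigl(z(x)\bigr) \\
\mathrm{POI}(x) &= \Phi\bigl(z(x)\bigr)
\end{aligned}

acq_max then approximates the arg max of the chosen acquisition over the bounds with a random warm-up pass followed by multi-start L-BFGS-B on its negative, since scipy.optimize.minimize only minimizes.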
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py View file @ 9fb25ccc
...
@@ -30,8 +30,8 @@ import json_tricks
 from nni.protocol import CommandType, send
 from nni.msg_dispatcher_base import MsgDispatcherBase
-from nni.common import init_logger
-from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, randint_to_quniform
+from nni.common import init_logger, multi_phase_enabled
+from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward, randint_to_quniform
 import nni.parameter_expressions as parameter_expressions

 _logger = logging.getLogger(__name__)
...
@@ -144,7 +144,7 @@ class Bracket():
         self.configs_perf = []  # [ {id: [seq, acc]}, {}, ... ]
         self.num_configs_to_run = []  # [ n, n, n, ... ]
         self.num_finished_configs = []  # [ n, n, n, ... ]
-        self.optimize_mode = optimize_mode
+        self.optimize_mode = OptimizeMode(optimize_mode)
         self.no_more_trial = False

     def is_completed(self):
...
@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase):
     optimize_mode: str
         optimize mode, 'maximize' or 'minimize'
     """
-    def __init__(self, R, eta=3, optimize_mode='maximize'):
+    def __init__(self, R=60, eta=3, optimize_mode='maximize'):
         """B = (s_max + 1)R"""
         super(Hyperband, self).__init__()
         self.R = R  # pylint: disable=invalid-name
...
@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase):
         # In this case, tuner increases self.credit to issue a trial config sometime later.
         self.credit = 0
-
-    def load_checkpoint(self):
-        pass
-
-    def save_checkpoint(self):
-        pass
+        # record the latest parameter_id of the trial job trial_job_id.
+        # if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
+        # new trial job is added to this dict and finished trial job is removed from it.
+        self.job_id_para_id_map = dict()

     def handle_initialize(self, data):
         """data is search space
...
@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase):
            number of trial jobs
        """
        for _ in range(data):
-            self._request_one_trial_job()
+            ret = self._get_one_trial_job()
+            send(CommandType.NewTrialJob, json_tricks.dumps(ret))

-    def _request_one_trial_job(self):
+    def _get_one_trial_job(self):
        """get one trial job, i.e., one hyperparameter configuration."""
        if not self.generated_hyper_configs:
            if self.curr_s < 0:
...
@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase):
                'parameter_source': 'algorithm',
                'parameters': params[1]
            }
-        send(CommandType.NewTrialJob, json_tricks.dumps(ret))
+        return ret

    def handle_update_search_space(self, data):
        """data: JSON object, which is search space
...
@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase):
        randint_to_quniform(self.searchspace_json)
        self.random_state = np.random.RandomState()

+    def _handle_trial_end(self, parameter_id):
+        """
+        Parameters
+        ----------
+        parameter_id: parameter id of the finished config
+        """
+        bracket_id, i, _ = parameter_id.split('_')
+        hyper_configs = self.brackets[int(bracket_id)].inform_trial_end(int(i))
+        if hyper_configs is not None:
+            _logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs)
+            self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
+
    def handle_trial_end(self, data):
        """
        Parameters
...
@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase):
            hyper_params: the hyperparameters (a string) generated and returned by tuner
        """
        hyper_params = json_tricks.loads(data['hyper_params'])
-        bracket_id, i, _ = hyper_params['parameter_id'].split('_')
-        hyper_configs = self.brackets[int(bracket_id)].inform_trial_end(int(i))
-        if hyper_configs is not None:
-            _logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs)
-            self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
-            for _ in range(self.credit):
-                if not self.generated_hyper_configs:
-                    break
-                params = self.generated_hyper_configs.pop()
-                ret = {
-                    'parameter_id': params[0],
-                    'parameter_source': 'algorithm',
-                    'parameters': params[1]
-                }
-                send(CommandType.NewTrialJob, json_tricks.dumps(ret))
-                self.credit -= 1
+        self._handle_trial_end(hyper_params['parameter_id'])
+        if data['trial_job_id'] in self.job_id_para_id_map:
+            del self.job_id_para_id_map[data['trial_job_id']]

    def handle_report_metric_data(self, data):
        """
...
@@ -400,18 +400,40 @@ class Hyperband(MsgDispatcherBase):
        ValueError
            Data type not supported
        """
-        value = extract_scalar_reward(data['value'])
-        bracket_id, i, _ = data['parameter_id'].split('_')
-        bracket_id = int(bracket_id)
-        if data['type'] == 'FINAL':
-            # sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
-            # and PERIODICAL metric are independent, thus, not comparable.
-            self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value)
-            self.completed_hyper_configs.append(data)
-        elif data['type'] == 'PERIODICAL':
-            self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], data['sequence'], value)
-        else:
-            raise ValueError('Data type not supported: {}'.format(data['type']))
+        if data['type'] == MetricType.REQUEST_PARAMETER:
+            assert multi_phase_enabled()
+            assert data['trial_job_id'] is not None
+            assert data['parameter_index'] is not None
+            assert data['trial_job_id'] in self.job_id_para_id_map
+            self._handle_trial_end(self.job_id_para_id_map[data['trial_job_id']])
+            ret = self._get_one_trial_job()
+            if data['trial_job_id'] is not None:
+                ret['trial_job_id'] = data['trial_job_id']
+            if data['parameter_index'] is not None:
+                ret['parameter_index'] = data['parameter_index']
+            self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
+            send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
+        else:
+            value = extract_scalar_reward(data['value'])
+            bracket_id, i, _ = data['parameter_id'].split('_')
+            bracket_id = int(bracket_id)
+
+            # add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
+            # because when the first parameter_id is created, trial_job_id is not known yet.
+            if data['trial_job_id'] in self.job_id_para_id_map:
+                assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
+            else:
+                self.job_id_para_id_map[data['trial_job_id']] = data['parameter_id']
+
+            if data['type'] == MetricType.FINAL:
+                # sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
+                # and PERIODICAL metric are independent, thus, not comparable.
+                self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value)
+                self.completed_hyper_configs.append(data)
+            elif data['type'] == MetricType.PERIODICAL:
+                self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], data['sequence'], value)
+            else:
+                raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_add_customized_trial(self, data):
        pass
...
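The new job_id_para_id_map together with the REQUEST_PARAMETER branch is what lets the advisor serve several parameter configurations to the same trial job under multi-phase. A minimal standalone sketch of that bookkeeping (simplified names; the real advisor also updates brackets and talks to the dispatcher):

# Simplified illustration of the trial_job_id -> parameter_id map used above.
# This is a standalone sketch, not the Hyperband advisor itself.
job_id_para_id_map = {}

def on_metric(trial_job_id, parameter_id):
    # First metric report from a trial: remember which parameter it is running.
    if trial_job_id in job_id_para_id_map:
        assert job_id_para_id_map[trial_job_id] == parameter_id
    else:
        job_id_para_id_map[trial_job_id] = parameter_id

def on_trial_end(trial_job_id):
    # Finished trials are dropped so the map only tracks running jobs.
    job_id_para_id_map.pop(trial_job_id, None)

on_metric('job_1', '0_0_1')
on_trial_end('job_1')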
src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py View file @ 9fb25ccc
...
@@ -49,15 +49,16 @@ def selection_r(x_bounds,
                 num_starting_points=100,
                 minimize_constraints_fun=None):
     '''
-    Call selection
+    Select using different types.
     '''
-    minimize_starting_points = [lib_data.rand(x_bounds, x_types) \
-                                for i in range(0, num_starting_points)]
+    minimize_starting_points = clusteringmodel_gmm_good.sample(n_samples=num_starting_points)
     outputs = selection(x_bounds, x_types,
                         clusteringmodel_gmm_good,
                         clusteringmodel_gmm_bad,
-                        minimize_starting_points,
+                        minimize_starting_points[0],
                         minimize_constraints_fun)

     return outputs

 def selection(x_bounds,
...
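The rewritten selection_r draws its optimization starting points from the "good" GMM rather than from uniform random samples; the [0] index is needed because scikit-learn's GaussianMixture.sample returns a (samples, component_labels) pair. A small standalone sketch of that call on toy data (not the actual Metis models):

import numpy as np
from sklearn.mixture import GaussianMixture

# Toy stand-in for clusteringmodel_gmm_good: fit a GMM on a few 2-D points.
good_points = np.array([[0.1, 0.2], [0.15, 0.25], [0.8, 0.9], [0.85, 0.95]])
gmm_good = GaussianMixture(n_components=2).fit(good_points)

# sample() returns (X, labels); only the sampled points are used as starting points.
starting_points, _labels = gmm_good.sample(n_samples=5)
print(starting_points.shape)  # (5, 2)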
src/sdk/pynni/nni/metis_tuner/metis_tuner.py View file @ 9fb25ccc
...
@@ -20,15 +20,15 @@
 import copy
 import logging
-import numpy as np
 import os
 import random
 import statistics
 import sys
+import warnings
 from enum import Enum, unique
 from multiprocessing.dummy import Pool as ThreadPool

+import numpy as np
 import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation
 import nni.metis_tuner.lib_data as lib_data
 import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model
...
@@ -42,8 +42,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward
 logger = logging.getLogger("Metis_Tuner_AutoML")

 NONE_TYPE = ''
 CONSTRAINT_LOWERBOUND = None
 CONSTRAINT_UPPERBOUND = None
...
@@ -93,7 +91,7 @@ class MetisTuner(Tuner):
         self.space = None
         self.no_resampling = no_resampling
         self.no_candidates = no_candidates
-        self.optimize_mode = optimize_mode
+        self.optimize_mode = OptimizeMode(optimize_mode)
         self.key_order = []
         self.cold_start_num = cold_start_num
         self.selection_num_starting_points = selection_num_starting_points
...
@@ -254,6 +252,9 @@ class MetisTuner(Tuner):
                    threshold_samplessize_resampling=50, no_candidates=False,
                    minimize_starting_points=None, minimize_constraints_fun=None):
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+
         next_candidate = None
         candidates = []
         samples_size_all = sum([len(i) for i in samples_y])
...
@@ -271,13 +272,12 @@ class MetisTuner(Tuner):
                                       minimize_constraints_fun=minimize_constraints_fun)
         if not lm_current:
             return None
-        logger.info({'hyperparameter': lm_current['hyperparameter'],
+        if no_candidates is False:
+            candidates.append({'hyperparameter': lm_current['hyperparameter'],
                      'expected_mu': lm_current['expected_mu'],
                      'expected_sigma': lm_current['expected_sigma'],
                      'reason': "exploitation_gp"})
-        if no_candidates is False:
             # ===== STEP 2: Get recommended configurations for exploration =====
             results_exploration = gp_selection.selection(
                 "lc",
...
@@ -290,34 +290,48 @@ class MetisTuner(Tuner):
             if results_exploration is not None:
                 if _num_past_samples(results_exploration['hyperparameter'], samples_x, samples_y) == 0:
-                    candidates.append({'hyperparameter': results_exploration['hyperparameter'],
-                                       'expected_mu': results_exploration['expected_mu'],
-                                       'expected_sigma': results_exploration['expected_sigma'],
-                                       'reason': "exploration"})
+                    temp_candidate = {'hyperparameter': results_exploration['hyperparameter'],
+                                      'expected_mu': results_exploration['expected_mu'],
+                                      'expected_sigma': results_exploration['expected_sigma'],
+                                      'reason': "exploration"}
+                    candidates.append(temp_candidate)
                     logger.info("DEBUG: 1 exploration candidate selected\n")
+                    logger.info(temp_candidate)
             else:
                 logger.info("DEBUG: No suitable exploration candidates were")

             # ===== STEP 3: Get recommended configurations for exploitation =====
             if samples_size_all >= threshold_samplessize_exploitation:
-                print("Getting candidates for exploitation...\n")
+                logger.info("Getting candidates for exploitation...\n")
                 try:
                     gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation)
-                    results_exploitation = gmm_selection.selection(x_bounds,
-                                                                   x_types,
-                                                                   gmm['clusteringmodel_good'],
-                                                                   gmm['clusteringmodel_bad'],
-                                                                   minimize_starting_points,
-                                                                   minimize_constraints_fun=minimize_constraints_fun)
+
+                    if ("discrete_int" in x_types) or ("range_int" in x_types):
+                        results_exploitation = gmm_selection.selection(x_bounds,
+                                                                       x_types,
+                                                                       gmm['clusteringmodel_good'],
+                                                                       gmm['clusteringmodel_bad'],
+                                                                       minimize_starting_points,
+                                                                       minimize_constraints_fun=minimize_constraints_fun)
+                    else:
+                        # If all parameters are of "range_continuous", let's use GMM to generate random starting points
+                        results_exploitation = gmm_selection.selection_r(x_bounds,
+                                                                         x_types,
+                                                                         gmm['clusteringmodel_good'],
+                                                                         gmm['clusteringmodel_bad'],
+                                                                         num_starting_points=self.selection_num_starting_points,
+                                                                         minimize_constraints_fun=minimize_constraints_fun)

                     if results_exploitation is not None:
                         if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0:
-                            candidates.append({'hyperparameter': results_exploitation['hyperparameter'], \
-                                               'expected_mu': results_exploitation['expected_mu'], \
-                                               'expected_sigma': results_exploitation['expected_sigma'], \
-                                               'reason': "exploitation_gmm"})
+                            temp_expected_mu, temp_expected_sigma = gp_prediction.predict(results_exploitation['hyperparameter'], gp_model['model'])
+                            temp_candidate = {'hyperparameter': results_exploitation['hyperparameter'],
+                                              'expected_mu': temp_expected_mu,
+                                              'expected_sigma': temp_expected_sigma,
+                                              'reason': "exploitation_gmm"}
+                            candidates.append(temp_candidate)
                             logger.info("DEBUG: 1 exploitation_gmm candidate selected\n")
+                            logger.info(temp_candidate)
                     else:
                         logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n")
...
@@ -338,11 +352,13 @@ class MetisTuner(Tuner):
             if results_outliers is not None:
                 for results_outlier in results_outliers:
                     if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x:
-                        candidates.append({'hyperparameter': samples_x[results_outlier['samples_idx']], \
-                                           'expected_mu': results_outlier['expected_mu'], \
-                                           'expected_sigma': results_outlier['expected_sigma'], \
-                                           'reason': "resampling"})
+                        temp_candidate = {'hyperparameter': samples_x[results_outlier['samples_idx']], \
+                                          'expected_mu': results_outlier['expected_mu'], \
+                                          'expected_sigma': results_outlier['expected_sigma'], \
+                                          'reason': "resampling"}
+                        candidates.append(temp_candidate)
                 logger.info("DEBUG: %d re-sampling candidates selected\n")
+                logger.info(temp_candidate)
             else:
                 logger.info("DEBUG: No suitable resampling candidates were found\n")
...
src/sdk/pynni/nni/msg_dispatcher.py View file @ 9fb25ccc
...
@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
 from .assessor import AssessResult
 from .common import multi_thread_enabled, multi_phase_enabled
 from .env_vars import dispatcher_env_vars
+from .utils import MetricType

 _logger = logging.getLogger(__name__)
...
@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase):
        - 'value': metric value reported by nni.report_final_result()
        - 'type': report type, support {'FINAL', 'PERIODICAL'}
        """
-        if data['type'] == 'FINAL':
+        if data['type'] == MetricType.FINAL:
            self._handle_final_metric_data(data)
-        elif data['type'] == 'PERIODICAL':
+        elif data['type'] == MetricType.PERIODICAL:
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
-        elif data['type'] == 'REQUEST_PARAMETER':
+        elif data['type'] == MetricType.REQUEST_PARAMETER:
            assert multi_phase_enabled()
            assert data['trial_job_id'] is not None
            assert data['parameter_index'] is not None
...
@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results
        """
-        if data['type'] != 'PERIODICAL':
+        if data['type'] != MetricType.PERIODICAL:
            return
        if self.assessor is None:
            return
...
@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase):
        trial is early stopped.
        """
        _logger.debug('Early stop notify tuner data: [%s]', data)
-        data['type'] = 'FINAL'
+        data['type'] = MetricType.FINAL
        if multi_thread_enabled():
            self._handle_final_metric_data(data)
        else:
...
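These comparisons only work if the MetricType members compare equal to the 'FINAL' / 'PERIODICAL' / 'REQUEST_PARAMETER' strings reported by trials. A minimal sketch of such a helper, assuming plain string constants (the actual definition lives in nni/utils.py and may differ):

class MetricType:
    """Names for the metric report types used by the dispatcher (sketch only)."""
    FINAL = 'FINAL'
    PERIODICAL = 'PERIODICAL'
    REQUEST_PARAMETER = 'REQUEST_PARAMETER'

# Plain string constants keep `data['type'] == MetricType.FINAL` working for
# values parsed straight from JSON.
assert 'FINAL' == MetricType.FINAL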