Unverified Commit 9fb25ccc authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #189 from microsoft/master

merge master
parents 1500458a 7c4bc33b
...@@ -9,7 +9,7 @@ searchSpacePath: search_space.json ...@@ -9,7 +9,7 @@ searchSpacePath: search_space.json
#choice: true, false #choice: true, false
useAnnotation: false useAnnotation: false
tuner: tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl) #SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE builtinTunerName: TPE
classArgs: classArgs:
......
...@@ -32,7 +32,7 @@ trainingServicePlatform: local ...@@ -32,7 +32,7 @@ trainingServicePlatform: local
useAnnotation: false useAnnotation: false
tuner: tuner:
#可选项: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism #可选项: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
#SMAC (SMAC 需要通过 nnictl 安装) #SMAC (SMAC 需要通过 nnictl 安装)
builtinTunerName: NetworkMorphism builtinTunerName: NetworkMorphism
classArgs: classArgs:
#可选项: maximize, minimize #可选项: maximize, minimize
......
# 如何使用 ga_customer_tuner? # 如何使用 ga_customer_tuner?
此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad", 输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。 此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。
# 配置 # 配置
......
# 如何使用 ga_customer_tuner? # 如何使用 ga_customer_tuner?
此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad", 输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。 此定制的 Tuner 仅适用于代码 "~/nni/examples/trials/ga_squad",输入 `cd ~/nni/examples/trials/ga_squad` 查看 readme.md 来了解 ga_squad 的更多信息。
# 配置 # 配置
......
...@@ -54,6 +54,10 @@ ...@@ -54,6 +54,10 @@
"tslint-microsoft-contrib": "^6.0.0", "tslint-microsoft-contrib": "^6.0.0",
"typescript": "^3.2.2" "typescript": "^3.2.2"
}, },
"resolutions": {
"mem": "^4.0.0",
"handlebars": "^4.1.0"
},
"engines": { "engines": {
"node": ">=10.0.0" "node": ">=10.0.0"
}, },
......
...@@ -51,6 +51,7 @@ export namespace ValidationSchemas { ...@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
command: joi.string().min(1), command: joi.string().min(1),
virtualCluster: joi.string(), virtualCluster: joi.string(),
shmMB: joi.number(), shmMB: joi.number(),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode'),
worker: joi.object({ worker: joi.object({
replicas: joi.number().min(1).required(), replicas: joi.number().min(1).required(),
image: joi.string().min(1), image: joi.string().min(1),
...@@ -161,7 +162,7 @@ export namespace ValidationSchemas { ...@@ -161,7 +162,7 @@ export namespace ValidationSchemas {
checkpointDir: joi.string().allow('') checkpointDir: joi.string().allow('')
}), }),
tuner: joi.object({ tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner'), builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'),
codeDir: joi.string(), codeDir: joi.string(),
classFileName: joi.string(), classFileName: joi.string(),
className: joi.string(), className: joi.string(),
......
...@@ -145,6 +145,10 @@ ...@@ -145,6 +145,10 @@
"@types/minimatch" "*" "@types/minimatch" "*"
"@types/node" "*" "@types/node" "*"
"@types/js-base64@^2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@types/js-base64/-/js-base64-2.3.1.tgz#c39f14f129408a3d96a1105a650d8b2b6eeb4168"
"@types/mime@*": "@types/mime@*":
version "2.0.0" version "2.0.0"
resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b" resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"
...@@ -161,9 +165,9 @@ ...@@ -161,9 +165,9 @@
version "10.5.2" version "10.5.2"
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.5.2.tgz#f19f05314d5421fe37e74153254201a7bf00a707" resolved "https://registry.yarnpkg.com/@types/node/-/node-10.5.2.tgz#f19f05314d5421fe37e74153254201a7bf00a707"
"@types/node@^10.5.5": "@types/node@10.12.18":
version "10.5.5" version "10.12.18"
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.5.5.tgz#8e84d24e896cd77b0d4f73df274027e3149ec2ba" resolved "https://registry.yarnpkg.com/@types/node/-/node-10.12.18.tgz#1d3ca764718915584fcd9f6344621b7672665c67"
"@types/range-parser@*": "@types/range-parser@*":
version "1.2.2" version "1.2.2"
...@@ -342,10 +346,6 @@ ansi-regex@^3.0.0: ...@@ -342,10 +346,6 @@ ansi-regex@^3.0.0:
version "3.0.0" version "3.0.0"
resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-3.0.0.tgz#ed0317c322064f79466c02966bddb605ab37d998" resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-3.0.0.tgz#ed0317c322064f79466c02966bddb605ab37d998"
ansi-styles@^2.2.1:
version "2.2.1"
resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-2.2.1.tgz#b432dd3358b634cf75e1e4664368240533c1ddbe"
ansi-styles@^3.2.1: ansi-styles@^3.2.1:
version "3.2.1" version "3.2.1"
resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-3.2.1.tgz#41fbb20243e50b12be0f04b8dedbf07520ce841d" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-3.2.1.tgz#41fbb20243e50b12be0f04b8dedbf07520ce841d"
...@@ -413,12 +413,6 @@ async-limiter@~1.0.0: ...@@ -413,12 +413,6 @@ async-limiter@~1.0.0:
version "1.0.0" version "1.0.0"
resolved "https://registry.yarnpkg.com/async-limiter/-/async-limiter-1.0.0.tgz#78faed8c3d074ab81f22b4e985d79e8738f720f8" resolved "https://registry.yarnpkg.com/async-limiter/-/async-limiter-1.0.0.tgz#78faed8c3d074ab81f22b4e985d79e8738f720f8"
async@^2.5.0:
version "2.6.1"
resolved "https://registry.yarnpkg.com/async/-/async-2.6.1.tgz#b245a23ca71930044ec53fa46aa00a3e87c6a610"
dependencies:
lodash "^4.17.10"
asynckit@^0.4.0: asynckit@^0.4.0:
version "0.4.0" version "0.4.0"
resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79"
...@@ -451,14 +445,6 @@ azure-storage@^2.10.2: ...@@ -451,14 +445,6 @@ azure-storage@^2.10.2:
xml2js "0.2.8" xml2js "0.2.8"
xmlbuilder "^9.0.7" xmlbuilder "^9.0.7"
babel-code-frame@^6.22.0:
version "6.26.0"
resolved "https://registry.yarnpkg.com/babel-code-frame/-/babel-code-frame-6.26.0.tgz#63fd43f7dc1e3bb7ce35947db8fe369a3f58c74b"
dependencies:
chalk "^1.1.3"
esutils "^2.0.2"
js-tokens "^3.0.2"
balanced-match@^1.0.0: balanced-match@^1.0.0:
version "1.0.0" version "1.0.0"
resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.0.tgz#89b4d199ab2bee49de164ea02b89ce462d71b767" resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.0.tgz#89b4d199ab2bee49de164ea02b89ce462d71b767"
...@@ -575,16 +561,6 @@ chai@^4.1.2: ...@@ -575,16 +561,6 @@ chai@^4.1.2:
pathval "^1.0.0" pathval "^1.0.0"
type-detect "^4.0.0" type-detect "^4.0.0"
chalk@^1.1.3:
version "1.1.3"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-1.1.3.tgz#a8115c55e4a702fe4d150abd3872822a7e09fc98"
dependencies:
ansi-styles "^2.2.1"
escape-string-regexp "^1.0.2"
has-ansi "^2.0.0"
strip-ansi "^3.0.0"
supports-color "^2.0.0"
chalk@^2.0.0, chalk@^2.3.0: chalk@^2.0.0, chalk@^2.3.0:
version "2.4.1" version "2.4.1"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e" resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"
...@@ -854,7 +830,7 @@ escape-html@~1.0.3: ...@@ -854,7 +830,7 @@ escape-html@~1.0.3:
version "1.0.3" version "1.0.3"
resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988" resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"
escape-string-regexp@1.0.5, escape-string-regexp@^1.0.2, escape-string-regexp@^1.0.5: escape-string-regexp@1.0.5, escape-string-regexp@^1.0.5:
version "1.0.5" version "1.0.5"
resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4"
...@@ -1135,11 +1111,11 @@ growl@1.10.5: ...@@ -1135,11 +1111,11 @@ growl@1.10.5:
version "1.10.5" version "1.10.5"
resolved "https://registry.yarnpkg.com/growl/-/growl-1.10.5.tgz#f2735dc2283674fa67478b10181059355c369e5e" resolved "https://registry.yarnpkg.com/growl/-/growl-1.10.5.tgz#f2735dc2283674fa67478b10181059355c369e5e"
handlebars@^4.0.11: handlebars@^4.0.11, handlebars@^4.1.0:
version "4.0.12" version "4.1.2"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.0.12.tgz#2c15c8a96d46da5e266700518ba8cb8d919d5bc5" resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.1.2.tgz#b6b37c1ced0306b221e094fc7aca3ec23b131b67"
dependencies: dependencies:
async "^2.5.0" neo-async "^2.6.0"
optimist "^0.6.1" optimist "^0.6.1"
source-map "^0.6.1" source-map "^0.6.1"
optionalDependencies: optionalDependencies:
...@@ -1163,12 +1139,6 @@ har-validator@~5.1.0: ...@@ -1163,12 +1139,6 @@ har-validator@~5.1.0:
ajv "^5.3.0" ajv "^5.3.0"
har-schema "^2.0.0" har-schema "^2.0.0"
has-ansi@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/has-ansi/-/has-ansi-2.0.0.tgz#34f5049ce1ecdf2b0649af3ef24e45ed35416d91"
dependencies:
ansi-regex "^2.0.0"
has-flag@^3.0.0: has-flag@^3.0.0:
version "3.0.0" version "3.0.0"
resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd" resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd"
...@@ -1426,21 +1396,24 @@ js-base64@^2.4.9: ...@@ -1426,21 +1396,24 @@ js-base64@^2.4.9:
version "2.5.0" version "2.5.0"
resolved "https://registry.yarnpkg.com/js-base64/-/js-base64-2.5.0.tgz#42255ba183ab67ce59a0dee640afdc00ab5ae93e" resolved "https://registry.yarnpkg.com/js-base64/-/js-base64-2.5.0.tgz#42255ba183ab67ce59a0dee640afdc00ab5ae93e"
js-tokens@^3.0.2:
version "3.0.2"
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-3.0.2.tgz#9866df395102130e38f7f996bceb65443209c25b"
js-tokens@^4.0.0: js-tokens@^4.0.0:
version "4.0.0" version "4.0.0"
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
js-yaml@^3.10.0, js-yaml@^3.7.0: js-yaml@^3.10.0:
version "3.12.0" version "3.12.0"
resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.12.0.tgz#eaed656ec8344f10f527c6bfa1b6e2244de167d1" resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.12.0.tgz#eaed656ec8344f10f527c6bfa1b6e2244de167d1"
dependencies: dependencies:
argparse "^1.0.7" argparse "^1.0.7"
esprima "^4.0.0" esprima "^4.0.0"
js-yaml@^3.13.1:
version "3.13.1"
resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.13.1.tgz#aff151b30bfdfa8e49e05da22e7415e9dfa37847"
dependencies:
argparse "^1.0.7"
esprima "^4.0.0"
jsbn@~0.1.0: jsbn@~0.1.0:
version "0.1.1" version "0.1.1"
resolved "https://registry.yarnpkg.com/jsbn/-/jsbn-0.1.1.tgz#a5e654c2e5a2deb5f201d96cefbca80c0ef2f513" resolved "https://registry.yarnpkg.com/jsbn/-/jsbn-0.1.1.tgz#a5e654c2e5a2deb5f201d96cefbca80c0ef2f513"
...@@ -1619,6 +1592,12 @@ make-error@^1.1.1: ...@@ -1619,6 +1592,12 @@ make-error@^1.1.1:
version "1.3.4" version "1.3.4"
resolved "https://registry.yarnpkg.com/make-error/-/make-error-1.3.4.tgz#19978ed575f9e9545d2ff8c13e33b5d18a67d535" resolved "https://registry.yarnpkg.com/make-error/-/make-error-1.3.4.tgz#19978ed575f9e9545d2ff8c13e33b5d18a67d535"
map-age-cleaner@^0.1.1:
version "0.1.3"
resolved "https://registry.yarnpkg.com/map-age-cleaner/-/map-age-cleaner-0.1.3.tgz#7d583a7306434c055fe474b0f45078e6e1b4b92a"
dependencies:
p-defer "^1.0.0"
md5-hex@^2.0.0: md5-hex@^2.0.0:
version "2.0.0" version "2.0.0"
resolved "https://registry.yarnpkg.com/md5-hex/-/md5-hex-2.0.0.tgz#d0588e9f1c74954492ecd24ac0ac6ce997d92e33" resolved "https://registry.yarnpkg.com/md5-hex/-/md5-hex-2.0.0.tgz#d0588e9f1c74954492ecd24ac0ac6ce997d92e33"
...@@ -1640,11 +1619,13 @@ media-typer@0.3.0: ...@@ -1640,11 +1619,13 @@ media-typer@0.3.0:
version "0.3.0" version "0.3.0"
resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748" resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748"
mem@^1.1.0: mem@^1.1.0, mem@^4.0.0:
version "1.1.0" version "4.3.0"
resolved "https://registry.yarnpkg.com/mem/-/mem-1.1.0.tgz#5edd52b485ca1d900fe64895505399a0dfa45f76" resolved "https://registry.yarnpkg.com/mem/-/mem-4.3.0.tgz#461af497bc4ae09608cdb2e60eefb69bff744178"
dependencies: dependencies:
mimic-fn "^1.0.0" map-age-cleaner "^0.1.1"
mimic-fn "^2.0.0"
p-is-promise "^2.0.0"
merge-descriptors@1.0.1: merge-descriptors@1.0.1:
version "1.0.1" version "1.0.1"
...@@ -1684,9 +1665,9 @@ mime@1.4.1: ...@@ -1684,9 +1665,9 @@ mime@1.4.1:
version "1.4.1" version "1.4.1"
resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6" resolved "https://registry.yarnpkg.com/mime/-/mime-1.4.1.tgz#121f9ebc49e3766f311a76e1fa1c8003c4b03aa6"
mimic-fn@^1.0.0: mimic-fn@^2.0.0:
version "1.2.0" version "2.1.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022" resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b"
mimic-response@^1.0.0: mimic-response@^1.0.0:
version "1.0.1" version "1.0.1"
...@@ -1773,6 +1754,10 @@ negotiator@0.6.1: ...@@ -1773,6 +1754,10 @@ negotiator@0.6.1:
version "0.6.1" version "0.6.1"
resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.1.tgz#2b327184e8992101177b28563fb5e7102acd0ca9" resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.1.tgz#2b327184e8992101177b28563fb5e7102acd0ca9"
neo-async@^2.6.0:
version "2.6.1"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.1.tgz#ac27ada66167fa8849a6addd837f6b189ad2081c"
node-forge@^0.7.6: node-forge@^0.7.6:
version "0.7.6" version "0.7.6"
resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.7.6.tgz#fdf3b418aee1f94f0ef642cd63486c77ca9724ac" resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.7.6.tgz#fdf3b418aee1f94f0ef642cd63486c77ca9724ac"
...@@ -1797,12 +1782,6 @@ node-jose@^1.1.0: ...@@ -1797,12 +1782,6 @@ node-jose@^1.1.0:
node-forge "^0.7.6" node-forge "^0.7.6"
uuid "^3.3.2" uuid "^3.3.2"
node-nvidia-smi@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/node-nvidia-smi/-/node-nvidia-smi-1.0.0.tgz#6aa57574540b2bed91c9a80218516ffa686e5ac7"
dependencies:
xml2js "^0.4.17"
node-pre-gyp@^0.10.3: node-pre-gyp@^0.10.3:
version "0.10.3" version "0.10.3"
resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc" resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc"
...@@ -2005,6 +1984,10 @@ p-cancelable@^0.4.0: ...@@ -2005,6 +1984,10 @@ p-cancelable@^0.4.0:
version "0.4.1" version "0.4.1"
resolved "http://registry.npmjs.org/p-cancelable/-/p-cancelable-0.4.1.tgz#35f363d67d52081c8d9585e37bcceb7e0bbcb2a0" resolved "http://registry.npmjs.org/p-cancelable/-/p-cancelable-0.4.1.tgz#35f363d67d52081c8d9585e37bcceb7e0bbcb2a0"
p-defer@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/p-defer/-/p-defer-1.0.0.tgz#9f6eb182f6c9aa8cd743004a7d4f96b196b0fb0c"
p-finally@^1.0.0: p-finally@^1.0.0:
version "1.0.0" version "1.0.0"
resolved "https://registry.yarnpkg.com/p-finally/-/p-finally-1.0.0.tgz#3fbcfb15b899a44123b34b6dcc18b724336a2cae" resolved "https://registry.yarnpkg.com/p-finally/-/p-finally-1.0.0.tgz#3fbcfb15b899a44123b34b6dcc18b724336a2cae"
...@@ -2013,6 +1996,10 @@ p-is-promise@^1.1.0: ...@@ -2013,6 +1996,10 @@ p-is-promise@^1.1.0:
version "1.1.0" version "1.1.0"
resolved "http://registry.npmjs.org/p-is-promise/-/p-is-promise-1.1.0.tgz#9c9456989e9f6588017b0434d56097675c3da05e" resolved "http://registry.npmjs.org/p-is-promise/-/p-is-promise-1.1.0.tgz#9c9456989e9f6588017b0434d56097675c3da05e"
p-is-promise@^2.0.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/p-is-promise/-/p-is-promise-2.1.0.tgz#918cebaea248a62cf7ffab8e3bca8c5f882fc42e"
p-limit@^1.1.0: p-limit@^1.1.0:
version "1.3.0" version "1.3.0"
resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-1.3.0.tgz#b86bd5f0c25690911c7590fcbfc2010d54b3ccb8" resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-1.3.0.tgz#b86bd5f0c25690911c7590fcbfc2010d54b3ccb8"
...@@ -2384,7 +2371,7 @@ sax@0.5.x: ...@@ -2384,7 +2371,7 @@ sax@0.5.x:
version "0.5.8" version "0.5.8"
resolved "http://registry.npmjs.org/sax/-/sax-0.5.8.tgz#d472db228eb331c2506b0e8c15524adb939d12c1" resolved "http://registry.npmjs.org/sax/-/sax-0.5.8.tgz#d472db228eb331c2506b0e8c15524adb939d12c1"
sax@>=0.6.0, sax@^1.2.4: sax@^1.2.4:
version "1.2.4" version "1.2.4"
resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
...@@ -2619,10 +2606,6 @@ supports-color@5.4.0, supports-color@^5.3.0: ...@@ -2619,10 +2606,6 @@ supports-color@5.4.0, supports-color@^5.3.0:
dependencies: dependencies:
has-flag "^3.0.0" has-flag "^3.0.0"
supports-color@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-2.0.0.tgz#535d045ce6b6363fa40117084629995e9df324c7"
supports-color@^5.4.0: supports-color@^5.4.0:
version "5.5.0" version "5.5.0"
resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.5.0.tgz#e2e69a44ac8772f78a1ec0b35b689df6530efc8f" resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.5.0.tgz#e2e69a44ac8772f78a1ec0b35b689df6530efc8f"
...@@ -2716,30 +2699,37 @@ tslib@^1.8.0, tslib@^1.8.1: ...@@ -2716,30 +2699,37 @@ tslib@^1.8.0, tslib@^1.8.1:
version "1.9.3" version "1.9.3"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286" resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"
tslint-microsoft-contrib@^5.1.0: tslint-microsoft-contrib@^6.0.0:
version "5.1.0" version "6.2.0"
resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-5.1.0.tgz#777c32d51aba16f4565e47aac749a1631176cd9f" resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
dependencies: dependencies:
tsutils "^2.12.1" tsutils "^2.27.2 <2.29.0"
tslint@^5.11.0: tslint@^5.12.0:
version "5.11.0" version "5.18.0"
resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.11.0.tgz#98f30c02eae3cde7006201e4c33cb08b48581eed" resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
dependencies: dependencies:
babel-code-frame "^6.22.0" "@babel/code-frame" "^7.0.0"
builtin-modules "^1.1.1" builtin-modules "^1.1.1"
chalk "^2.3.0" chalk "^2.3.0"
commander "^2.12.1" commander "^2.12.1"
diff "^3.2.0" diff "^3.2.0"
glob "^7.1.1" glob "^7.1.1"
js-yaml "^3.7.0" js-yaml "^3.13.1"
minimatch "^3.0.4" minimatch "^3.0.4"
mkdirp "^0.5.1"
resolve "^1.3.2" resolve "^1.3.2"
semver "^5.3.0" semver "^5.3.0"
tslib "^1.8.0" tslib "^1.8.0"
tsutils "^2.27.2" tsutils "^2.29.0"
"tsutils@^2.27.2 <2.29.0":
version "2.28.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
dependencies:
tslib "^1.8.1"
tsutils@^2.12.1, tsutils@^2.27.2: tsutils@^2.29.0:
version "2.29.0" version "2.29.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99" resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
dependencies: dependencies:
...@@ -2777,9 +2767,9 @@ typescript-string-operations@^1.3.1: ...@@ -2777,9 +2767,9 @@ typescript-string-operations@^1.3.1:
version "1.3.1" version "1.3.1"
resolved "https://registry.yarnpkg.com/typescript-string-operations/-/typescript-string-operations-1.3.1.tgz#461b886cc9ccd4dd16810b1f248b2e6f6580956b" resolved "https://registry.yarnpkg.com/typescript-string-operations/-/typescript-string-operations-1.3.1.tgz#461b886cc9ccd4dd16810b1f248b2e6f6580956b"
typescript@^3.0.1: typescript@^3.2.2:
version "3.0.1" version "3.5.2"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.0.1.tgz#43738f29585d3a87575520a4b93ab6026ef11fdb" resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.2.tgz#a09e1dc69bc9551cadf17dba10ee42cf55e5d56c"
uglify-js@^3.1.4: uglify-js@^3.1.4:
version "3.4.9" version "3.4.9"
...@@ -2900,14 +2890,7 @@ xml2js@0.2.8: ...@@ -2900,14 +2890,7 @@ xml2js@0.2.8:
dependencies: dependencies:
sax "0.5.x" sax "0.5.x"
xml2js@^0.4.17: xmlbuilder@^9.0.7:
version "0.4.19"
resolved "https://registry.yarnpkg.com/xml2js/-/xml2js-0.4.19.tgz#686c20f213209e94abf0d1bcf1efaa291c7827a7"
dependencies:
sax ">=0.6.0"
xmlbuilder "~9.0.1"
xmlbuilder@^9.0.7, xmlbuilder@~9.0.1:
version "9.0.7" version "9.0.7"
resolved "http://registry.npmjs.org/xmlbuilder/-/xmlbuilder-9.0.7.tgz#132ee63d2ec5565c557e20f4c22df9aca686b10d" resolved "http://registry.npmjs.org/xmlbuilder/-/xmlbuilder-9.0.7.tgz#132ee63d2ec5565c557e20f4c22df9aca686b10d"
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
from .trial import * from .trial import *
from .smartparam import * from .smartparam import *
from .nas_utils import reload_tensorflow_variables
class NoMoreTrialError(Exception): class NoMoreTrialError(Exception):
def __init__(self,ErrorInfo): def __init__(self,ErrorInfo):
......
...@@ -130,8 +130,6 @@ def main(): ...@@ -130,8 +130,6 @@ def main():
if args.advisor_class_name: if args.advisor_class_name:
# advisor is enabled and starts to run # advisor is enabled and starts to run
if args.multi_phase:
raise AssertionError('multi_phase has not been supported in advisor')
if args.advisor_class_name in AdvisorModuleName: if args.advisor_class_name in AdvisorModuleName:
dispatcher = create_builtin_class_instance( dispatcher = create_builtin_class_instance(
args.advisor_class_name, args.advisor_class_name,
......
...@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH ...@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH
from nni.protocol import CommandType, send from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.utils import OptimizeMode, extract_scalar_reward, randint_to_quniform from nni.utils import OptimizeMode, MetricType, extract_scalar_reward, randint_to_quniform
from nni.common import multi_phase_enabled
from .config_generator import CG_BOHB from .config_generator import CG_BOHB
...@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- ...@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
return params_id return params_id
class Bracket(): class Bracket(object):
""" """
A bracket in BOHB, all the information of a bracket is managed by A bracket in BOHB, all the information of a bracket is managed by
an instance of this class. an instance of this class.
...@@ -106,7 +107,7 @@ class Bracket(): ...@@ -106,7 +107,7 @@ class Bracket():
self.s_max = s_max self.s_max = s_max
self.eta = eta self.eta = eta
self.max_budget = max_budget self.max_budget = max_budget
self.optimize_mode = optimize_mode self.optimize_mode = OptimizeMode(optimize_mode)
self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon) self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon)
self.r = max_budget / eta**s self.r = max_budget / eta**s
...@@ -259,33 +260,34 @@ class BOHB(MsgDispatcherBase): ...@@ -259,33 +260,34 @@ class BOHB(MsgDispatcherBase):
optimize_mode: str optimize_mode: str
optimize mode, 'maximize' or 'minimize' optimize mode, 'maximize' or 'minimize'
min_budget: float min_budget: float
The smallest budget to consider. Needs to be positive! The smallest budget to consider. Needs to be positive!
max_budget: float max_budget: float
The largest budget to consider. Needs to be larger than min_budget! The largest budget to consider. Needs to be larger than min_budget!
The budgets will be geometrically distributed The budgets will be geometrically distributed
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`. :math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
eta: int eta: int
In each iteration, a complete run of sequential halving is executed. In it, In each iteration, a complete run of sequential halving is executed. In it,
after evaluating each configuration on the same subset size, only a fraction of after evaluating each configuration on the same subset size, only a fraction of
1/eta of them 'advances' to the next round. 1/eta of them 'advances' to the next round.
Must be greater or equal to 2. Must be greater or equal to 2.
min_points_in_model: int min_points_in_model: int
number of observations to start building a KDE. Default 'None' means number of observations to start building a KDE. Default 'None' means
dim+1, the bare minimum. dim+1, the bare minimum.
top_n_percent: int top_n_percent: int
percentage ( between 1 and 99, default 15) of the observations that are considered good. percentage ( between 1 and 99, default 15) of the observations that are considered good.
num_samples: int num_samples: int
number of samples to optimize EI (default 64) number of samples to optimize EI (default 64)
random_fraction: float random_fraction: float
fraction of purely random configurations that are sampled from the fraction of purely random configurations that are sampled from the
prior without the model. prior without the model.
bandwidth_factor: float bandwidth_factor: float
to encourage diversity, the points proposed to optimize EI, are sampled to encourage diversity, the points proposed to optimize EI, are sampled
from a 'widened' KDE where the bandwidth is multiplied by this factor (default: 3) from a 'widened' KDE where the bandwidth is multiplied by this factor (default: 3)
min_bandwidth: float min_bandwidth: float
to keep diversity, even when all (good) samples have the same value for one of the parameters, to keep diversity, even when all (good) samples have the same value for one of the parameters,
a minimum bandwidth (Default: 1e-3) is used instead of zero. a minimum bandwidth (Default: 1e-3) is used instead of zero.
""" """
def __init__(self, def __init__(self,
optimize_mode='maximize', optimize_mode='maximize',
min_budget=1, min_budget=1,
...@@ -328,11 +330,12 @@ class BOHB(MsgDispatcherBase): ...@@ -328,11 +330,12 @@ class BOHB(MsgDispatcherBase):
# config generator # config generator
self.cg = None self.cg = None
def load_checkpoint(self): # record the latest parameter_id of the trial job trial_job_id.
pass # if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
def save_checkpoint(self): self.job_id_para_id_map = dict()
pass # record the unsatisfied parameter request from trial jobs
self.unsatisfied_jobs = []
def handle_initialize(self, data): def handle_initialize(self, data):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models """Initialize Tuner, including creating Bayesian optimization-based parametric models
...@@ -398,7 +401,7 @@ class BOHB(MsgDispatcherBase): ...@@ -398,7 +401,7 @@ class BOHB(MsgDispatcherBase):
for _ in range(self.credit): for _ in range(self.credit):
self._request_one_trial_job() self._request_one_trial_job()
def _request_one_trial_job(self): def _get_one_trial_job(self):
"""get one trial job, i.e., one hyperparameter configuration. """get one trial job, i.e., one hyperparameter configuration.
If this function is called, Command will be sent by BOHB: If this function is called, Command will be sent by BOHB:
...@@ -422,7 +425,7 @@ class BOHB(MsgDispatcherBase): ...@@ -422,7 +425,7 @@ class BOHB(MsgDispatcherBase):
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
return return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop() params = self.generated_hyper_configs.pop()
ret = { ret = {
...@@ -431,8 +434,29 @@ class BOHB(MsgDispatcherBase): ...@@ -431,8 +434,29 @@ class BOHB(MsgDispatcherBase):
'parameters': params[1] 'parameters': params[1]
} }
self.parameters[params[0]] = params[1] self.parameters[params[0]] = params[1]
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) return ret
self.credit -= 1
def _request_one_trial_job(self):
"""get one trial job, i.e., one hyperparameter configuration.
If this function is called, Command will be sent by BOHB:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
{
'parameter_id': id of new hyperparameter
'parameter_source': 'algorithm'
'parameters': value of new hyperparameter
}
b. If BOHB don't have parameter waiting, will return "NoMoreTrialJobs" with
{
'parameter_id': '-1_0_0',
'parameter_source': 'algorithm',
'parameters': ''
}
"""
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
self.credit -= 1
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
"""change json format to ConfigSpace format dict<dict> -> configspace """change json format to ConfigSpace format dict<dict> -> configspace
...@@ -501,23 +525,38 @@ class BOHB(MsgDispatcherBase): ...@@ -501,23 +525,38 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
logger.debug('Tuner handle trial end, result is %s', data) logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = json_tricks.loads(data['hyper_params'])
s, i, _ = hyper_params['parameter_id'].split('_') self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
def _send_new_trial(self):
while self.unsatisfied_jobs:
ret = self._get_one_trial_job()
if ret is None:
break
one_unsatisfied = self.unsatisfied_jobs.pop(0)
ret['trial_job_id'] = one_unsatisfied['trial_job_id']
ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
for _ in range(self.credit):
self._request_one_trial_job()
def _handle_trial_end(self, parameter_id):
s, i, _ = parameter_id.split('_')
hyper_configs = self.brackets[int(s)].inform_trial_end(int(i)) hyper_configs = self.brackets[int(s)].inform_trial_end(int(i))
if hyper_configs is not None: if hyper_configs is not None:
logger.debug( logger.debug(
'bracket %s next round %s, hyper_configs: %s', s, i, hyper_configs) 'bracket %s next round %s, hyper_configs: %s', s, i, hyper_configs)
self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
for _ in range(self.credit):
self._request_one_trial_job()
# Finish this bracket and generate a new bracket # Finish this bracket and generate a new bracket
elif self.brackets[int(s)].no_more_trial: elif self.brackets[int(s)].no_more_trial:
self.curr_s -= 1 self.curr_s -= 1
self.generate_new_bracket() self.generate_new_bracket()
for _ in range(self.credit): self._send_new_trial()
self._request_one_trial_job()
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
"""reveice the metric data and update Bayesian optimization with final result """reveice the metric data and update Bayesian optimization with final result
...@@ -534,36 +573,58 @@ class BOHB(MsgDispatcherBase): ...@@ -534,36 +573,58 @@ class BOHB(MsgDispatcherBase):
""" """
logger.debug('handle report metric data = %s', data) logger.debug('handle report metric data = %s', data)
assert 'value' in data if data['type'] == MetricType.REQUEST_PARAMETER:
value = extract_scalar_reward(data['value']) assert multi_phase_enabled()
if self.optimize_mode is OptimizeMode.Maximize: assert data['trial_job_id'] is not None
reward = -value assert data['parameter_index'] is not None
else: assert data['trial_job_id'] in self.job_id_para_id_map
reward = value self._handle_trial_end(self.job_id_para_id_map[data['trial_job_id']])
assert 'parameter_id' in data ret = self._get_one_trial_job()
s, i, _ = data['parameter_id'].split('_') if ret is None:
self.unsatisfied_jobs.append({'trial_job_id': data['trial_job_id'], 'parameter_index': data['parameter_index']})
logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type']) else:
s = int(s) ret['trial_job_id'] = data['trial_job_id']
ret['parameter_index'] = data['parameter_index']
assert 'type' in data # update parameter_id in self.job_id_para_id_map
if data['type'] == 'FINAL': self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
# and PERIODICAL metric are independent, thus, not comparable. send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
assert 'sequence' in data
self.brackets[s].set_config_perf(
int(i), data['parameter_id'], sys.maxsize, value)
self.completed_hyper_configs.append(data)
_parameters = self.parameters[data['parameter_id']]
_parameters.pop(_KEY)
# update BO with loss, max_s budget, hyperparameters
self.cg.new_result(loss=reward, budget=data['sequence'], parameters=_parameters, update_model=True)
elif data['type'] == 'PERIODICAL':
self.brackets[s].set_config_perf(
int(i), data['parameter_id'], data['sequence'], value)
else: else:
raise ValueError( assert 'value' in data
'Data type not supported: {}'.format(data['type'])) value = extract_scalar_reward(data['value'])
if self.optimize_mode is OptimizeMode.Maximize:
reward = -value
else:
reward = value
assert 'parameter_id' in data
s, i, _ = data['parameter_id'].split('_')
logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
s = int(s)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if data['trial_job_id'] in self.job_id_para_id_map:
assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
else:
self.job_id_para_id_map[data['trial_job_id']] = data['parameter_id']
assert 'type' in data
if data['type'] == MetricType.FINAL:
# and PERIODICAL metric are independent, thus, not comparable.
assert 'sequence' in data
self.brackets[s].set_config_perf(
int(i), data['parameter_id'], sys.maxsize, value)
self.completed_hyper_configs.append(data)
_parameters = self.parameters[data['parameter_id']]
_parameters.pop(_KEY)
# update BO with loss, max_s budget, hyperparameters
self.cg.new_result(loss=reward, budget=data['sequence'], parameters=_parameters, update_model=True)
elif data['type'] == MetricType.PERIODICAL:
self.brackets[s].set_config_perf(
int(i), data['parameter_id'], data['sequence'], value)
else:
raise ValueError(
'Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
pass pass
......
...@@ -29,7 +29,8 @@ ModuleName = { ...@@ -29,7 +29,8 @@ ModuleName = {
'GridSearch': 'nni.gridsearch_tuner.gridsearch_tuner', 'GridSearch': 'nni.gridsearch_tuner.gridsearch_tuner',
'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner', 'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner',
'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor', 'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
'MetisTuner': 'nni.metis_tuner.metis_tuner' 'MetisTuner': 'nni.metis_tuner.metis_tuner',
'GPTuner': 'nni.gp_tuner.gp_tuner'
} }
ClassName = { ClassName = {
...@@ -42,6 +43,7 @@ ClassName = { ...@@ -42,6 +43,7 @@ ClassName = {
'GridSearch': 'GridSearchTuner', 'GridSearch': 'GridSearchTuner',
'NetworkMorphism':'NetworkMorphismTuner', 'NetworkMorphism':'NetworkMorphismTuner',
'MetisTuner':'MetisTuner', 'MetisTuner':'MetisTuner',
'GPTuner':'GPTuner',
'Medianstop': 'MedianstopAssessor', 'Medianstop': 'MedianstopAssessor',
'Curvefitting': 'CurvefittingAssessor' 'Curvefitting': 'CurvefittingAssessor'
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy as np
import unittest import unittest
from .curvefitting_assessor import CurvefittingAssessor from .curvefitting_assessor import CurvefittingAssessor
from .model_factory import CurveModel
from nni.assessor import AssessResult from nni.assessor import AssessResult
class TestCurveFittingAssessor(unittest.TestCase): class TestCurveFittingAssessor(unittest.TestCase):
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
gp_tuner.py
'''
import warnings
import logging
import numpy as np
from sklearn.gaussian_process.kernels import Matern
from sklearn.gaussian_process import GaussianProcessRegressor
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from .target_space import TargetSpace
from .util import UtilityFunction, acq_max
logger = logging.getLogger("GP_Tuner_AutoML")
class GPTuner(Tuner):
'''
GPTuner
'''
def __init__(self, optimize_mode="maximize", utility='ei', kappa=5, xi=0, nu=2.5, alpha=1e-6, cold_start_num=10,
selection_num_warm_up=100000, selection_num_starting_points=250):
self.optimize_mode = OptimizeMode(optimize_mode)
# utility function related
self.utility = utility
self.kappa = kappa
self.xi = xi
# target space
self._space = None
self._random_state = np.random.RandomState()
# nu, alpha are GPR related params
self._gp = GaussianProcessRegressor(
kernel=Matern(nu=nu),
alpha=alpha,
normalize_y=True,
n_restarts_optimizer=25,
random_state=self._random_state
)
# num of random evaluations before GPR
self._cold_start_num = cold_start_num
# params for acq_max
self._selection_num_warm_up = selection_num_warm_up
self._selection_num_starting_points = selection_num_starting_points
# num of imported data
self.supplement_data_num = 0
def update_search_space(self, search_space):
"""Update the self.bounds and self.types by the search_space.json
Parameters
----------
search_space : dict
"""
self._space = TargetSpace(search_space, self._random_state)
def generate_parameters(self, parameter_id):
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
gp will first randomly generate some parameters.
Otherwise, choose the parameters by the Gussian Process Model
Parameters
----------
parameter_id : int
Returns
-------
result : dict
"""
if self._space.len() < self._cold_start_num:
results = self._space.random_sample()
else:
# Sklearn's GP throws a large number of warnings at times, but
# we don't really need to see them here.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self._gp.fit(self._space.params, self._space.target)
util = UtilityFunction(
kind=self.utility, kappa=self.kappa, xi=self.xi)
results = acq_max(
f_acq=util.utility,
gp=self._gp,
y_max=self._space.target.max(),
bounds=self._space.bounds,
space=self._space,
num_warmup=self._selection_num_warm_up,
num_starting_points=self._selection_num_starting_points
)
results = self._space.array_to_params(results)
logger.info("Generate paramageters:\n %s", results)
return results
def receive_trial_result(self, parameter_id, parameters, value):
"""Tuner receive result from trial.
Parameters
----------
parameter_id : int
parameters : dict
value : dict/float
if value is dict, it should have "default" key.
"""
value = extract_scalar_reward(value)
if self.optimize_mode == OptimizeMode.Minimize:
value = -value
logger.info("Received trial result.")
logger.info("value :%s", value)
logger.info("parameter : %s", parameters)
self._space.register(parameters, value)
def import_data(self, data):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %
(_completed_num, len(data)))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info(
"Useless trial data, value is %s, skip this trial data." % _value)
continue
self.supplement_data_num += 1
_parameter_id = '_'.join(
["ImportData", str(self.supplement_data_num)])
self.receive_trial_result(
parameter_id=_parameter_id, parameters=_params, value=_value)
logger.info("Successfully import data to GP tuner.")
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
target_space.py
'''
import numpy as np
import nni.parameter_expressions as parameter_expressions
def _hashable(params):
""" ensure that an point is hashable by a python dict """
return tuple(map(float, params))
class TargetSpace():
"""
Holds the param-space coordinates (X) and target values (Y)
"""
def __init__(self, pbounds, random_state=None):
"""
Parameters
----------
pbounds : dict
Dictionary with parameters names as keys and a tuple with minimum
and maximum values.
random_state : int, RandomState, or None
optionally specify a seed for a random number generator
"""
self.random_state = random_state
# Get the name of the parameters
self._keys = sorted(pbounds)
# Create an array with parameters bounds
self._bounds = np.array(
[item[1] for item in sorted(pbounds.items(), key=lambda x: x[0])]
)
# preallocated memory for X and Y points
self._params = np.empty(shape=(0, self.dim))
self._target = np.empty(shape=(0))
# keep track of unique points we have seen so far
self._cache = {}
def __contains__(self, params):
'''
check if a parameter is already registered
'''
return _hashable(params) in self._cache
def len(self):
'''
length of registered params and targets
'''
assert len(self._params) == len(self._target)
return len(self._target)
@property
def params(self):
'''
params: numpy array
'''
return self._params
@property
def target(self):
'''
target: numpy array
'''
return self._target
@property
def dim(self):
'''
dim: int
length of keys
'''
return len(self._keys)
@property
def keys(self):
'''
keys: numpy array
'''
return self._keys
@property
def bounds(self):
'''bounds'''
return self._bounds
def params_to_array(self, params):
''' dict to array '''
try:
assert set(params) == set(self.keys)
except AssertionError:
raise ValueError(
"Parameters' keys ({}) do ".format(sorted(params)) +
"not match the expected set of keys ({}).".format(self.keys)
)
return np.asarray([params[key] for key in self.keys])
def array_to_params(self, x):
'''
array to dict
maintain int type if the paramters is defined as int in search_space.json
'''
try:
assert len(x) == len(self.keys)
except AssertionError:
raise ValueError(
"Size of array ({}) is different than the ".format(len(x)) +
"expected number of parameters ({}).".format(self.dim())
)
params = {}
for i, _bound in enumerate(self._bounds):
if _bound['_type'] == 'choice' and all(isinstance(val, int) for val in _bound['_value']):
params.update({self.keys[i]: int(x[i])})
elif _bound['_type'] in ['randint']:
params.update({self.keys[i]: int(x[i])})
else:
params.update({self.keys[i]: x[i]})
return params
def register(self, params, target):
"""
Append a point and its target value to the known data.
Parameters
----------
x : dict
y : float
target function value
"""
x = self.params_to_array(params)
if x in self:
#raise KeyError('Data point {} is not unique'.format(x))
print('Data point {} is not unique'.format(x))
# Insert data into unique dictionary
self._cache[_hashable(x.ravel())] = target
self._params = np.concatenate([self._params, x.reshape(1, -1)])
self._target = np.concatenate([self._target, [target]])
def random_sample(self):
"""
Creates a random point within the bounds of the space.
"""
params = np.empty(self.dim)
for col, _bound in enumerate(self._bounds):
if _bound['_type'] == 'choice':
params[col] = parameter_expressions.choice(
_bound['_value'], self.random_state)
elif _bound['_type'] == 'randint':
params[col] = self.random_state.randint(
_bound['_value'][0], _bound['_value'][1], size=1)
elif _bound['_type'] == 'uniform':
params[col] = parameter_expressions.uniform(
_bound['_value'][0], _bound['_value'][1], self.random_state)
elif _bound['_type'] == 'quniform':
params[col] = parameter_expressions.quniform(
_bound['_value'][0], _bound['_value'][1], _bound['_value'][2], self.random_state)
elif _bound['_type'] == 'loguniform':
params[col] = parameter_expressions.loguniform(
_bound['_value'][0], _bound['_value'][1], self.random_state)
elif _bound['_type'] == 'qloguniform':
params[col] = parameter_expressions.qloguniform(
_bound['_value'][0], _bound['_value'][1], _bound['_value'][2], self.random_state)
return params
def max(self):
"""Get maximum target value found and corresponding parametes."""
try:
res = {
'target': self.target.max(),
'params': dict(
zip(self.keys, self.params[self.target.argmax()])
)
}
except ValueError:
res = {}
return res
def res(self):
"""Get all target values found and corresponding parametes."""
params = [dict(zip(self.keys, p)) for p in self.params]
return [
{"target": target, "params": param}
for target, param in zip(self.target, params)
]
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
gp_tuner.py
'''
import warnings
import numpy as np
from scipy.stats import norm
from scipy.optimize import minimize
def _match_val_type(vals, bounds):
'''
Update values in the array, to match their corresponding type
'''
vals_new = []
for i, bound in enumerate(bounds):
_type = bound['_type']
if _type == "choice":
# Find the closest integer in the array, vals_bounds
vals_new.append(
min(bound['_value'], key=lambda x: abs(x - vals[i])))
elif _type in ['quniform', 'randint']:
vals_new.append(np.around(vals[i]))
else:
vals_new.append(vals[i])
return vals_new
def acq_max(f_acq, gp, y_max, bounds, space, num_warmup, num_starting_points):
"""
A function to find the maximum of the acquisition function
It uses a combination of random sampling (cheap) and the 'L-BFGS-B'
optimization method. First by sampling `n_warmup` (1e5) points at random,
and then running L-BFGS-B from `n_iter` (250) random starting points.
Parameters
----------
:param f_acq:
The acquisition function object that return its point-wise value.
:param gp:
A gaussian process fitted to the relevant data.
:param y_max:
The current maximum known value of the target function.
:param bounds:
The variables bounds to limit the search of the acq max.
:param num_warmup:
number of times to randomly sample the aquisition function
:param num_starting_points:
number of times to run scipy.minimize
Returns
-------
:return: x_max, The arg max of the acquisition function.
"""
# Warm up with random points
x_tries = [space.random_sample()
for _ in range(int(num_warmup))]
ys = f_acq(x_tries, gp=gp, y_max=y_max)
x_max = x_tries[ys.argmax()]
max_acq = ys.max()
# Explore the parameter space more throughly
x_seeds = [space.random_sample() for _ in range(int(num_starting_points))]
bounds_minmax = np.array(
[[bound['_value'][0], bound['_value'][-1]] for bound in bounds])
for x_try in x_seeds:
# Find the minimum of minus the acquisition function
res = minimize(lambda x: -f_acq(x.reshape(1, -1), gp=gp, y_max=y_max),
x_try.reshape(1, -1),
bounds=bounds_minmax,
method="L-BFGS-B")
# See if success
if not res.success:
continue
# Store it if better than previous minimum(maximum).
if max_acq is None or -res.fun[0] >= max_acq:
x_max = _match_val_type(res.x, bounds)
max_acq = -res.fun[0]
# Clip output to make sure it lies within the bounds. Due to floating
# point technicalities this is not always the case.
return np.clip(x_max, bounds_minmax[:, 0], bounds_minmax[:, 1])
class UtilityFunction():
"""
An object to compute the acquisition functions.
"""
def __init__(self, kind, kappa, xi):
"""
If UCB is to be used, a constant kappa is needed.
"""
self.kappa = kappa
self.xi = xi
if kind not in ['ucb', 'ei', 'poi']:
err = "The utility function " \
"{} has not been implemented, " \
"please choose one of ucb, ei, or poi.".format(kind)
raise NotImplementedError(err)
self.kind = kind
def utility(self, x, gp, y_max):
'''return utility function'''
if self.kind == 'ucb':
return self._ucb(x, gp, self.kappa)
if self.kind == 'ei':
return self._ei(x, gp, y_max, self.xi)
if self.kind == 'poi':
return self._poi(x, gp, y_max, self.xi)
return None
@staticmethod
def _ucb(x, gp, kappa):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mean, std = gp.predict(x, return_std=True)
return mean + kappa * std
@staticmethod
def _ei(x, gp, y_max, xi):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mean, std = gp.predict(x, return_std=True)
z = (mean - y_max - xi)/std
return (mean - y_max - xi) * norm.cdf(z) + std * norm.pdf(z)
@staticmethod
def _poi(x, gp, y_max, xi):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mean, std = gp.predict(x, return_std=True)
z = (mean - y_max - xi)/std
return norm.cdf(z)
...@@ -30,8 +30,8 @@ import json_tricks ...@@ -30,8 +30,8 @@ import json_tricks
from nni.protocol import CommandType, send from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger from nni.common import init_logger, multi_phase_enabled
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, randint_to_quniform from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward, randint_to_quniform
import nni.parameter_expressions as parameter_expressions import nni.parameter_expressions as parameter_expressions
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -144,7 +144,7 @@ class Bracket(): ...@@ -144,7 +144,7 @@ class Bracket():
self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ] self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ]
self.num_configs_to_run = [] # [ n, n, n, ... ] self.num_configs_to_run = [] # [ n, n, n, ... ]
self.num_finished_configs = [] # [ n, n, n, ... ] self.num_finished_configs = [] # [ n, n, n, ... ]
self.optimize_mode = optimize_mode self.optimize_mode = OptimizeMode(optimize_mode)
self.no_more_trial = False self.no_more_trial = False
def is_completed(self): def is_completed(self):
...@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase):
optimize_mode: str optimize_mode: str
optimize mode, 'maximize' or 'minimize' optimize mode, 'maximize' or 'minimize'
""" """
def __init__(self, R, eta=3, optimize_mode='maximize'): def __init__(self, R=60, eta=3, optimize_mode='maximize'):
"""B = (s_max + 1)R""" """B = (s_max + 1)R"""
super(Hyperband, self).__init__() super(Hyperband, self).__init__()
self.R = R # pylint: disable=invalid-name self.R = R # pylint: disable=invalid-name
...@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase): ...@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase):
# In this case, tuner increases self.credit to issue a trial config sometime later. # In this case, tuner increases self.credit to issue a trial config sometime later.
self.credit = 0 self.credit = 0
def load_checkpoint(self): # record the latest parameter_id of the trial job trial_job_id.
pass # if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
def save_checkpoint(self): self.job_id_para_id_map = dict()
pass
def handle_initialize(self, data): def handle_initialize(self, data):
"""data is search space """data is search space
...@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase): ...@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase):
number of trial jobs number of trial jobs
""" """
for _ in range(data): for _ in range(data):
self._request_one_trial_job() ret = self._get_one_trial_job()
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
def _request_one_trial_job(self): def _get_one_trial_job(self):
"""get one trial job, i.e., one hyperparameter configuration.""" """get one trial job, i.e., one hyperparameter configuration."""
if not self.generated_hyper_configs: if not self.generated_hyper_configs:
if self.curr_s < 0: if self.curr_s < 0:
...@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase): ...@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': params[1] 'parameters': params[1]
} }
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) return ret
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
"""data: JSON object, which is search space """data: JSON object, which is search space
...@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase): ...@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase):
randint_to_quniform(self.searchspace_json) randint_to_quniform(self.searchspace_json)
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
def _handle_trial_end(self, parameter_id):
"""
Parameters
----------
parameter_id: parameter id of the finished config
"""
bracket_id, i, _ = parameter_id.split('_')
hyper_configs = self.brackets[int(bracket_id)].inform_trial_end(int(i))
if hyper_configs is not None:
_logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs)
self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
def handle_trial_end(self, data): def handle_trial_end(self, data):
""" """
Parameters Parameters
...@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase): ...@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = json_tricks.loads(data['hyper_params'])
bracket_id, i, _ = hyper_params['parameter_id'].split('_') self._handle_trial_end(hyper_params['parameter_id'])
hyper_configs = self.brackets[int(bracket_id)].inform_trial_end(int(i)) if data['trial_job_id'] in self.job_id_para_id_map:
if hyper_configs is not None: del self.job_id_para_id_map[data['trial_job_id']]
_logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs)
self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
for _ in range(self.credit):
if not self.generated_hyper_configs:
break
params = self.generated_hyper_configs.pop()
ret = {
'parameter_id': params[0],
'parameter_source': 'algorithm',
'parameters': params[1]
}
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
self.credit -= 1
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
""" """
...@@ -400,18 +400,40 @@ class Hyperband(MsgDispatcherBase): ...@@ -400,18 +400,40 @@ class Hyperband(MsgDispatcherBase):
ValueError ValueError
Data type not supported Data type not supported
""" """
value = extract_scalar_reward(data['value']) if data['type'] == MetricType.REQUEST_PARAMETER:
bracket_id, i, _ = data['parameter_id'].split('_') assert multi_phase_enabled()
bracket_id = int(bracket_id) assert data['trial_job_id'] is not None
if data['type'] == 'FINAL': assert data['parameter_index'] is not None
# sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric assert data['trial_job_id'] in self.job_id_para_id_map
# and PERIODICAL metric are independent, thus, not comparable. self._handle_trial_end(self.job_id_para_id_map[data['trial_job_id']])
self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value) ret = self._get_one_trial_job()
self.completed_hyper_configs.append(data) if data['trial_job_id'] is not None:
elif data['type'] == 'PERIODICAL': ret['trial_job_id'] = data['trial_job_id']
self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], data['sequence'], value) if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
else: else:
raise ValueError('Data type not supported: {}'.format(data['type'])) value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_')
bracket_id = int(bracket_id)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if data['trial_job_id'] in self.job_id_para_id_map:
assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
else:
self.job_id_para_id_map[data['trial_job_id']] = data['parameter_id']
if data['type'] == MetricType.FINAL:
# sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
# and PERIODICAL metric are independent, thus, not comparable.
self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], sys.maxsize, value)
self.completed_hyper_configs.append(data)
elif data['type'] == MetricType.PERIODICAL:
self.brackets[bracket_id].set_config_perf(int(i), data['parameter_id'], data['sequence'], value)
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
pass pass
......
...@@ -49,15 +49,16 @@ def selection_r(x_bounds, ...@@ -49,15 +49,16 @@ def selection_r(x_bounds,
num_starting_points=100, num_starting_points=100,
minimize_constraints_fun=None): minimize_constraints_fun=None):
''' '''
Call selection Select using different types.
''' '''
minimize_starting_points = [lib_data.rand(x_bounds, x_types)\ minimize_starting_points = clusteringmodel_gmm_good.sample(n_samples=num_starting_points)
for i in range(0, num_starting_points)]
outputs = selection(x_bounds, x_types, outputs = selection(x_bounds, x_types,
clusteringmodel_gmm_good, clusteringmodel_gmm_good,
clusteringmodel_gmm_bad, clusteringmodel_gmm_bad,
minimize_starting_points, minimize_starting_points[0],
minimize_constraints_fun) minimize_constraints_fun)
return outputs return outputs
def selection(x_bounds, def selection(x_bounds,
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
import copy import copy
import logging import logging
import numpy as np
import os import os
import random import random
import statistics import statistics
import sys import sys
import warnings
from enum import Enum, unique from enum import Enum, unique
from multiprocessing.dummy import Pool as ThreadPool from multiprocessing.dummy import Pool as ThreadPool
import numpy as np
import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation
import nni.metis_tuner.lib_data as lib_data import nni.metis_tuner.lib_data as lib_data
import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model
...@@ -42,8 +42,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward ...@@ -42,8 +42,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward
logger = logging.getLogger("Metis_Tuner_AutoML") logger = logging.getLogger("Metis_Tuner_AutoML")
NONE_TYPE = '' NONE_TYPE = ''
CONSTRAINT_LOWERBOUND = None CONSTRAINT_LOWERBOUND = None
CONSTRAINT_UPPERBOUND = None CONSTRAINT_UPPERBOUND = None
...@@ -93,7 +91,7 @@ class MetisTuner(Tuner): ...@@ -93,7 +91,7 @@ class MetisTuner(Tuner):
self.space = None self.space = None
self.no_resampling = no_resampling self.no_resampling = no_resampling
self.no_candidates = no_candidates self.no_candidates = no_candidates
self.optimize_mode = optimize_mode self.optimize_mode = OptimizeMode(optimize_mode)
self.key_order = [] self.key_order = []
self.cold_start_num = cold_start_num self.cold_start_num = cold_start_num
self.selection_num_starting_points = selection_num_starting_points self.selection_num_starting_points = selection_num_starting_points
...@@ -254,6 +252,9 @@ class MetisTuner(Tuner): ...@@ -254,6 +252,9 @@ class MetisTuner(Tuner):
threshold_samplessize_resampling=50, no_candidates=False, threshold_samplessize_resampling=50, no_candidates=False,
minimize_starting_points=None, minimize_constraints_fun=None): minimize_starting_points=None, minimize_constraints_fun=None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
next_candidate = None next_candidate = None
candidates = [] candidates = []
samples_size_all = sum([len(i) for i in samples_y]) samples_size_all = sum([len(i) for i in samples_y])
...@@ -271,13 +272,12 @@ class MetisTuner(Tuner): ...@@ -271,13 +272,12 @@ class MetisTuner(Tuner):
minimize_constraints_fun=minimize_constraints_fun) minimize_constraints_fun=minimize_constraints_fun)
if not lm_current: if not lm_current:
return None return None
logger.info({'hyperparameter': lm_current['hyperparameter'],
if no_candidates is False:
candidates.append({'hyperparameter': lm_current['hyperparameter'],
'expected_mu': lm_current['expected_mu'], 'expected_mu': lm_current['expected_mu'],
'expected_sigma': lm_current['expected_sigma'], 'expected_sigma': lm_current['expected_sigma'],
'reason': "exploitation_gp"}) 'reason': "exploitation_gp"})
if no_candidates is False:
# ===== STEP 2: Get recommended configurations for exploration ===== # ===== STEP 2: Get recommended configurations for exploration =====
results_exploration = gp_selection.selection( results_exploration = gp_selection.selection(
"lc", "lc",
...@@ -290,34 +290,48 @@ class MetisTuner(Tuner): ...@@ -290,34 +290,48 @@ class MetisTuner(Tuner):
if results_exploration is not None: if results_exploration is not None:
if _num_past_samples(results_exploration['hyperparameter'], samples_x, samples_y) == 0: if _num_past_samples(results_exploration['hyperparameter'], samples_x, samples_y) == 0:
candidates.append({'hyperparameter': results_exploration['hyperparameter'], temp_candidate = {'hyperparameter': results_exploration['hyperparameter'],
'expected_mu': results_exploration['expected_mu'], 'expected_mu': results_exploration['expected_mu'],
'expected_sigma': results_exploration['expected_sigma'], 'expected_sigma': results_exploration['expected_sigma'],
'reason': "exploration"}) 'reason': "exploration"}
candidates.append(temp_candidate)
logger.info("DEBUG: 1 exploration candidate selected\n") logger.info("DEBUG: 1 exploration candidate selected\n")
logger.info(temp_candidate)
else: else:
logger.info("DEBUG: No suitable exploration candidates were") logger.info("DEBUG: No suitable exploration candidates were")
# ===== STEP 3: Get recommended configurations for exploitation ===== # ===== STEP 3: Get recommended configurations for exploitation =====
if samples_size_all >= threshold_samplessize_exploitation: if samples_size_all >= threshold_samplessize_exploitation:
print("Getting candidates for exploitation...\n") logger.info("Getting candidates for exploitation...\n")
try: try:
gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation) gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation)
results_exploitation = gmm_selection.selection(
x_bounds, if ("discrete_int" in x_types) or ("range_int" in x_types):
x_types, results_exploitation = gmm_selection.selection(x_bounds, x_types,
gmm['clusteringmodel_good'], gmm['clusteringmodel_good'],
gmm['clusteringmodel_bad'], gmm['clusteringmodel_bad'],
minimize_starting_points, minimize_starting_points,
minimize_constraints_fun=minimize_constraints_fun) minimize_constraints_fun=minimize_constraints_fun)
else:
# If all parameters are of "range_continuous", let's use GMM to generate random starting points
results_exploitation = gmm_selection.selection_r(x_bounds, x_types,
gmm['clusteringmodel_good'],
gmm['clusteringmodel_bad'],
num_starting_points=self.selection_num_starting_points,
minimize_constraints_fun=minimize_constraints_fun)
if results_exploitation is not None: if results_exploitation is not None:
if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0: if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0:
candidates.append({'hyperparameter': results_exploitation['hyperparameter'],\ temp_expected_mu, temp_expected_sigma = gp_prediction.predict(results_exploitation['hyperparameter'], gp_model['model'])
'expected_mu': results_exploitation['expected_mu'],\ temp_candidate = {'hyperparameter': results_exploitation['hyperparameter'],
'expected_sigma': results_exploitation['expected_sigma'],\ 'expected_mu': temp_expected_mu,
'reason': "exploitation_gmm"}) 'expected_sigma': temp_expected_sigma,
'reason': "exploitation_gmm"}
candidates.append(temp_candidate)
logger.info("DEBUG: 1 exploitation_gmm candidate selected\n") logger.info("DEBUG: 1 exploitation_gmm candidate selected\n")
logger.info(temp_candidate)
else: else:
logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n") logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n")
...@@ -338,11 +352,13 @@ class MetisTuner(Tuner): ...@@ -338,11 +352,13 @@ class MetisTuner(Tuner):
if results_outliers is not None: if results_outliers is not None:
for results_outlier in results_outliers: for results_outlier in results_outliers:
if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x: if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x:
candidates.append({'hyperparameter': samples_x[results_outlier['samples_idx']],\ temp_candidate = {'hyperparameter': samples_x[results_outlier['samples_idx']],\
'expected_mu': results_outlier['expected_mu'],\ 'expected_mu': results_outlier['expected_mu'],\
'expected_sigma': results_outlier['expected_sigma'],\ 'expected_sigma': results_outlier['expected_sigma'],\
'reason': "resampling"}) 'reason': "resampling"}
candidates.append(temp_candidate)
logger.info("DEBUG: %d re-sampling candidates selected\n") logger.info("DEBUG: %d re-sampling candidates selected\n")
logger.info(temp_candidate)
else: else:
logger.info("DEBUG: No suitable resampling candidates were found\n") logger.info("DEBUG: No suitable resampling candidates were found\n")
......
...@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase ...@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult from .assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from .utils import MetricType
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result() - 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'} - 'type': report type, support {'FINAL', 'PERIODICAL'}
""" """
if data['type'] == 'FINAL': if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
elif data['type'] == 'PERIODICAL': elif data['type'] == MetricType.PERIODICAL:
if self.assessor is not None: if self.assessor is not None:
self._handle_intermediate_metric_data(data) self._handle_intermediate_metric_data(data)
elif data['type'] == 'REQUEST_PARAMETER': elif data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None assert data['parameter_index'] is not None
...@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
def _handle_intermediate_metric_data(self, data): def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results """Call assessor to process intermediate results
""" """
if data['type'] != 'PERIODICAL': if data['type'] != MetricType.PERIODICAL:
return return
if self.assessor is None: if self.assessor is None:
return return
...@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase):
trial is early stopped. trial is early stopped.
""" """
_logger.debug('Early stop notify tuner data: [%s]', data) _logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = 'FINAL' data['type'] = MetricType.FINAL
if multi_thread_enabled(): if multi_thread_enabled():
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment