Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d48ad027
Unverified
Commit
d48ad027
authored
Jun 20, 2019
by
SparkSnail
Committed by
GitHub
Jun 20, 2019
Browse files
Merge pull request #184 from microsoft/master
merge master
parents
9352cc88
22993e5d
Changes
187
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
95 additions
and
81 deletions
+95
-81
src/nni_manager/training_service/remote_machine/sshClientUtility.ts
...nager/training_service/remote_machine/sshClientUtility.ts
+23
-15
src/nni_manager/training_service/test/hdfsClientUtility.test.ts
...i_manager/training_service/test/hdfsClientUtility.test.ts
+4
-4
src/nni_manager/training_service/test/kubeflowTrainingService.test.ts
...ger/training_service/test/kubeflowTrainingService.test.ts
+2
-2
src/nni_manager/training_service/test/localTrainingService.test.ts
...anager/training_service/test/localTrainingService.test.ts
+2
-2
src/nni_manager/training_service/test/paiTrainingService.test.ts
..._manager/training_service/test/paiTrainingService.test.ts
+1
-1
src/nni_manager/tslint.json
src/nni_manager/tslint.json
+4
-1
src/nni_manager/types/child-process-promise/index.d.ts
src/nni_manager/types/child-process-promise/index.d.ts
+1
-1
src/nni_manager/types/webhdfs/index.d.ts
src/nni_manager/types/webhdfs/index.d.ts
+3
-0
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+1
-1
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+11
-11
src/sdk/pynni/nni/bohb_advisor/config_generator.py
src/sdk/pynni/nni/bohb_advisor/config_generator.py
+1
-1
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+1
-1
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
.../pynni/nni/curvefitting_assessor/curvefitting_assessor.py
+2
-2
src/sdk/pynni/nni/curvefitting_assessor/curvefunctions.py
src/sdk/pynni/nni/curvefitting_assessor/curvefunctions.py
+3
-3
src/sdk/pynni/nni/curvefitting_assessor/model_factory.py
src/sdk/pynni/nni/curvefitting_assessor/model_factory.py
+13
-13
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
+2
-2
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+11
-11
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
+5
-5
src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py
src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py
+1
-1
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
.../pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
+4
-4
No files found.
src/nni_manager/training_service/remote_machine/sshClientUtility.ts
View file @
d48ad027
...
@@ -21,16 +21,16 @@
...
@@ -21,16 +21,16 @@
import
*
as
assert
from
'
assert
'
;
import
*
as
assert
from
'
assert
'
;
import
*
as
cpp
from
'
child-process-promise
'
;
import
*
as
cpp
from
'
child-process-promise
'
;
import
*
as
path
from
'
path
'
;
import
*
as
os
from
'
os
'
;
import
*
as
os
from
'
os
'
;
import
*
as
path
from
'
path
'
;
import
{
Client
,
ClientChannel
,
SFTPWrapper
}
from
'
ssh2
'
;
import
{
Client
,
ClientChannel
,
SFTPWrapper
}
from
'
ssh2
'
;
import
*
as
stream
from
'
stream
'
;
import
*
as
stream
from
'
stream
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
NNIError
,
NNIErrorNames
}
from
'
../../common/errors
'
;
import
{
NNIError
,
NNIErrorNames
}
from
'
../../common/errors
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
uniqueString
,
getRemoteTmpDir
,
unixPathJoin
}
from
'
../../common/utils
'
;
import
{
getRemoteTmpDir
,
uniqueString
,
unixPathJoin
}
from
'
../../common/utils
'
;
import
{
RemoteCommandResult
}
from
'
./remoteMachineData
'
;
import
{
execRemove
,
tarAdd
}
from
'
../common/util
'
;
import
{
execRemove
,
tarAdd
}
from
'
../common/util
'
;
import
{
RemoteCommandResult
}
from
'
./remoteMachineData
'
;
/**
/**
*
*
...
@@ -44,7 +44,8 @@ export namespace SSHClientUtility {
...
@@ -44,7 +44,8 @@ export namespace SSHClientUtility {
* @param remoteDirectory remote directory
* @param remoteDirectory remote directory
* @param sshClient SSH client
* @param sshClient SSH client
*/
*/
export
async
function
copyDirectoryToRemote
(
localDirectory
:
string
,
remoteDirectory
:
string
,
sshClient
:
Client
,
remoteOS
:
string
)
:
Promise
<
void
>
{
export
async
function
copyDirectoryToRemote
(
localDirectory
:
string
,
remoteDirectory
:
string
,
sshClient
:
Client
,
remoteOS
:
string
)
:
Promise
<
void
>
{
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
tmpTarName
:
string
=
`
${
uniqueString
(
10
)}
.tar.gz`
;
const
tmpTarName
:
string
=
`
${
uniqueString
(
10
)}
.tar.gz`
;
const
localTarPath
:
string
=
path
.
join
(
os
.
tmpdir
(),
tmpTarName
);
const
localTarPath
:
string
=
path
.
join
(
os
.
tmpdir
(),
tmpTarName
);
...
@@ -75,7 +76,7 @@ export namespace SSHClientUtility {
...
@@ -75,7 +76,7 @@ export namespace SSHClientUtility {
assert
(
sshClient
!==
undefined
);
assert
(
sshClient
!==
undefined
);
const
deferred
:
Deferred
<
boolean
>
=
new
Deferred
<
boolean
>
();
const
deferred
:
Deferred
<
boolean
>
=
new
Deferred
<
boolean
>
();
sshClient
.
sftp
((
err
:
Error
,
sftp
:
SFTPWrapper
)
=>
{
sshClient
.
sftp
((
err
:
Error
,
sftp
:
SFTPWrapper
)
=>
{
if
(
err
)
{
if
(
err
!==
undefined
&&
err
!==
null
)
{
log
.
error
(
`copyFileToRemote:
${
err
.
message
}
,
${
localFilePath
}
,
${
remoteFilePath
}
`
);
log
.
error
(
`copyFileToRemote:
${
err
.
message
}
,
${
localFilePath
}
,
${
remoteFilePath
}
`
);
deferred
.
reject
(
err
);
deferred
.
reject
(
err
);
...
@@ -84,7 +85,7 @@ export namespace SSHClientUtility {
...
@@ -84,7 +85,7 @@ export namespace SSHClientUtility {
assert
(
sftp
!==
undefined
);
assert
(
sftp
!==
undefined
);
sftp
.
fastPut
(
localFilePath
,
remoteFilePath
,
(
fastPutErr
:
Error
)
=>
{
sftp
.
fastPut
(
localFilePath
,
remoteFilePath
,
(
fastPutErr
:
Error
)
=>
{
sftp
.
end
();
sftp
.
end
();
if
(
fastPutErr
)
{
if
(
fastPutErr
!==
undefined
&&
fastPutErr
!==
null
)
{
deferred
.
reject
(
fastPutErr
);
deferred
.
reject
(
fastPutErr
);
}
else
{
}
else
{
deferred
.
resolve
(
true
);
deferred
.
resolve
(
true
);
...
@@ -100,6 +101,7 @@ export namespace SSHClientUtility {
...
@@ -100,6 +101,7 @@ export namespace SSHClientUtility {
* @param command the command to execute remotely
* @param command the command to execute remotely
* @param client SSH Client
* @param client SSH Client
*/
*/
// tslint:disable:no-unsafe-any no-any
export
function
remoteExeCommand
(
command
:
string
,
client
:
Client
):
Promise
<
RemoteCommandResult
>
{
export
function
remoteExeCommand
(
command
:
string
,
client
:
Client
):
Promise
<
RemoteCommandResult
>
{
const
log
:
Logger
=
getLogger
();
const
log
:
Logger
=
getLogger
();
log
.
debug
(
`remoteExeCommand: command: [
${
command
}
]`
);
log
.
debug
(
`remoteExeCommand: command: [
${
command
}
]`
);
...
@@ -109,7 +111,7 @@ export namespace SSHClientUtility {
...
@@ -109,7 +111,7 @@ export namespace SSHClientUtility {
let
exitCode
:
number
;
let
exitCode
:
number
;
client
.
exec
(
command
,
(
err
:
Error
,
channel
:
ClientChannel
)
=>
{
client
.
exec
(
command
,
(
err
:
Error
,
channel
:
ClientChannel
)
=>
{
if
(
err
)
{
if
(
err
!==
undefined
&&
err
!==
null
)
{
log
.
error
(
`remoteExeCommand:
${
err
.
message
}
`
);
log
.
error
(
`remoteExeCommand:
${
err
.
message
}
`
);
deferred
.
reject
(
err
);
deferred
.
reject
(
err
);
...
@@ -117,13 +119,14 @@ export namespace SSHClientUtility {
...
@@ -117,13 +119,14 @@ export namespace SSHClientUtility {
}
}
channel
.
on
(
'
data
'
,
(
data
:
any
,
dataStderr
:
any
)
=>
{
channel
.
on
(
'
data
'
,
(
data
:
any
,
dataStderr
:
any
)
=>
{
if
(
dataStderr
)
{
if
(
dataStderr
!==
undefined
&&
dataStderr
!==
null
)
{
stderr
+=
data
.
toString
();
stderr
+=
data
.
toString
();
}
else
{
}
else
{
stdout
+=
data
.
toString
();
stdout
+=
data
.
toString
();
}
}
}).
on
(
'
exit
'
,
(
code
,
signal
)
=>
{
})
exitCode
=
code
as
number
;
.
on
(
'
exit
'
,
(
code
:
any
,
signal
:
any
)
=>
{
exitCode
=
<
number
>
code
;
deferred
.
resolve
({
deferred
.
resolve
({
stdout
:
stdout
,
stdout
:
stdout
,
stderr
:
stderr
,
stderr
:
stderr
,
...
@@ -138,8 +141,9 @@ export namespace SSHClientUtility {
...
@@ -138,8 +141,9 @@ export namespace SSHClientUtility {
export
function
getRemoteFileContent
(
filePath
:
string
,
sshClient
:
Client
):
Promise
<
string
>
{
export
function
getRemoteFileContent
(
filePath
:
string
,
sshClient
:
Client
):
Promise
<
string
>
{
const
deferred
:
Deferred
<
string
>
=
new
Deferred
<
string
>
();
const
deferred
:
Deferred
<
string
>
=
new
Deferred
<
string
>
();
sshClient
.
sftp
((
err
:
Error
,
sftp
:
SFTPWrapper
)
=>
{
sshClient
.
sftp
((
err
:
Error
,
sftp
:
SFTPWrapper
)
=>
{
if
(
err
)
{
if
(
err
!==
undefined
&&
err
!==
null
)
{
getLogger
().
error
(
`getRemoteFileContent:
${
err
.
message
}
`
);
getLogger
()
.
error
(
`getRemoteFileContent:
${
err
.
message
}
`
);
deferred
.
reject
(
new
Error
(
`SFTP error:
${
err
.
message
}
`
));
deferred
.
reject
(
new
Error
(
`SFTP error:
${
err
.
message
}
`
));
return
;
return
;
...
@@ -150,16 +154,19 @@ export namespace SSHClientUtility {
...
@@ -150,16 +154,19 @@ export namespace SSHClientUtility {
let
dataBuffer
:
string
=
''
;
let
dataBuffer
:
string
=
''
;
sftpStream
.
on
(
'
data
'
,
(
data
:
Buffer
|
string
)
=>
{
sftpStream
.
on
(
'
data
'
,
(
data
:
Buffer
|
string
)
=>
{
dataBuffer
+=
data
;
dataBuffer
+=
data
;
}).
on
(
'
error
'
,
(
streamErr
:
Error
)
=>
{
})
.
on
(
'
error
'
,
(
streamErr
:
Error
)
=>
{
sftp
.
end
();
sftp
.
end
();
deferred
.
reject
(
new
NNIError
(
NNIErrorNames
.
NOT_FOUND
,
streamErr
.
message
));
deferred
.
reject
(
new
NNIError
(
NNIErrorNames
.
NOT_FOUND
,
streamErr
.
message
));
}).
on
(
'
end
'
,
()
=>
{
})
.
on
(
'
end
'
,
()
=>
{
// sftp connection need to be released manually once operation is done
// sftp connection need to be released manually once operation is done
sftp
.
end
();
sftp
.
end
();
deferred
.
resolve
(
dataBuffer
);
deferred
.
resolve
(
dataBuffer
);
});
});
}
catch
(
error
)
{
}
catch
(
error
)
{
getLogger
().
error
(
`getRemoteFileContent:
${
error
.
message
}
`
);
getLogger
()
.
error
(
`getRemoteFileContent:
${
error
.
message
}
`
);
sftp
.
end
();
sftp
.
end
();
deferred
.
reject
(
new
Error
(
`SFTP error:
${
error
.
message
}
`
));
deferred
.
reject
(
new
Error
(
`SFTP error:
${
error
.
message
}
`
));
}
}
...
@@ -167,4 +174,5 @@ export namespace SSHClientUtility {
...
@@ -167,4 +174,5 @@ export namespace SSHClientUtility {
return
deferred
.
promise
;
return
deferred
.
promise
;
}
}
// tslint:enable:no-unsafe-any no-any
}
}
src/nni_manager/training_service/test/hdfsClientUtility.test.ts
View file @
d48ad027
...
@@ -37,7 +37,7 @@ describe('WebHDFS', function () {
...
@@ -37,7 +37,7 @@ describe('WebHDFS', function () {
{
{
"user": "user1",
"user": "user1",
"port": 50070,
"port": 50070,
"host": "10.0.0.0"
"host": "10.0.0.0"
}
}
*/
*/
let
skip
:
boolean
=
false
;
let
skip
:
boolean
=
false
;
...
@@ -45,7 +45,7 @@ describe('WebHDFS', function () {
...
@@ -45,7 +45,7 @@ describe('WebHDFS', function () {
let
hdfsClient
:
any
;
let
hdfsClient
:
any
;
try
{
try
{
testHDFSInfo
=
JSON
.
parse
(
fs
.
readFileSync
(
'
../../.vscode/hdfsInfo.json
'
,
'
utf8
'
));
testHDFSInfo
=
JSON
.
parse
(
fs
.
readFileSync
(
'
../../.vscode/hdfsInfo.json
'
,
'
utf8
'
));
console
.
log
(
testHDFSInfo
);
console
.
log
(
testHDFSInfo
);
hdfsClient
=
WebHDFS
.
createClient
({
hdfsClient
=
WebHDFS
.
createClient
({
user
:
testHDFSInfo
.
user
,
user
:
testHDFSInfo
.
user
,
port
:
testHDFSInfo
.
port
,
port
:
testHDFSInfo
.
port
,
...
@@ -120,7 +120,7 @@ describe('WebHDFS', function () {
...
@@ -120,7 +120,7 @@ describe('WebHDFS', function () {
chai
.
expect
(
actualFileData
).
to
.
be
.
equals
(
testFileData
);
chai
.
expect
(
actualFileData
).
to
.
be
.
equals
(
testFileData
);
const
testHDFSDirPath
:
string
=
path
.
join
(
'
/nni_unittest_
'
+
uniqueString
(
6
)
+
'
_dir
'
);
const
testHDFSDirPath
:
string
=
path
.
join
(
'
/nni_unittest_
'
+
uniqueString
(
6
)
+
'
_dir
'
);
await
HDFSClientUtility
.
copyDirectoryToHdfs
(
tmpLocalDirectoryPath
,
testHDFSDirPath
,
hdfsClient
);
await
HDFSClientUtility
.
copyDirectoryToHdfs
(
tmpLocalDirectoryPath
,
testHDFSDirPath
,
hdfsClient
);
const
files
:
any
[]
=
await
HDFSClientUtility
.
readdir
(
testHDFSDirPath
,
hdfsClient
);
const
files
:
any
[]
=
await
HDFSClientUtility
.
readdir
(
testHDFSDirPath
,
hdfsClient
);
...
@@ -133,7 +133,7 @@ describe('WebHDFS', function () {
...
@@ -133,7 +133,7 @@ describe('WebHDFS', function () {
// Cleanup
// Cleanup
rmdir
(
tmpLocalDirectoryPath
);
rmdir
(
tmpLocalDirectoryPath
);
let
deleteRestult
:
boolean
=
await
HDFSClientUtility
.
deletePath
(
testHDFSFilePath
,
hdfsClient
);
let
deleteRestult
:
boolean
=
await
HDFSClientUtility
.
deletePath
(
testHDFSFilePath
,
hdfsClient
);
chai
.
expect
(
deleteRestult
).
to
.
be
.
equals
(
true
);
chai
.
expect
(
deleteRestult
).
to
.
be
.
equals
(
true
);
...
...
src/nni_manager/training_service/test/kubeflowTrainingService.test.ts
View file @
d48ad027
...
@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
...
@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
if
(
skip
)
{
if
(
skip
)
{
return
;
return
;
}
}
kubeflowTrainingService
=
component
.
get
(
KubeflowTrainingService
);
kubeflowTrainingService
=
component
.
get
(
KubeflowTrainingService
);
});
});
afterEach
(()
=>
{
afterEach
(()
=>
{
...
@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => {
...
@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => {
return
;
return
;
}
}
await
kubeflowTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
KUBEFLOW_CLUSTER_CONFIG
,
testKubeflowConfig
),
await
kubeflowTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
KUBEFLOW_CLUSTER_CONFIG
,
testKubeflowConfig
),
await
kubeflowTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
,
testKubeflowTrialConfig
);
await
kubeflowTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
,
testKubeflowTrialConfig
);
});
});
});
});
\ No newline at end of file
src/nni_manager/training_service/test/localTrainingService.test.ts
View file @
d48ad027
...
@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => {
...
@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => {
//trial jobs should be empty, since there are no submitted jobs
//trial jobs should be empty, since there are no submitted jobs
chai
.
expect
(
await
localTrainingService
.
listTrialJobs
()).
to
.
be
.
empty
;
chai
.
expect
(
await
localTrainingService
.
listTrialJobs
()).
to
.
be
.
empty
;
});
});
it
(
'
setClusterMetadata and getClusterMetadata
'
,
async
()
=>
{
it
(
'
setClusterMetadata and getClusterMetadata
'
,
async
()
=>
{
await
localTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
,
trialConfig
);
await
localTrainingService
.
setClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
,
trialConfig
);
localTrainingService
.
getClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
).
then
((
data
)
=>
{
localTrainingService
.
getClusterMetadata
(
TrialConfigMetadataKey
.
TRIAL_CONFIG
).
then
((
data
)
=>
{
...
@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => {
...
@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => {
await
localTrainingService
.
cancelTrialJob
(
jobDetail
.
id
);
await
localTrainingService
.
cancelTrialJob
(
jobDetail
.
id
);
chai
.
expect
(
jobDetail
.
status
).
to
.
be
.
equals
(
'
USER_CANCELED
'
);
chai
.
expect
(
jobDetail
.
status
).
to
.
be
.
equals
(
'
USER_CANCELED
'
);
}).
timeout
(
20000
);
}).
timeout
(
20000
);
it
(
'
Read metrics, Add listener, and remove listener
'
,
async
()
=>
{
it
(
'
Read metrics, Add listener, and remove listener
'
,
async
()
=>
{
// set meta data
// set meta data
const
trialConfig
:
string
=
`{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"
${
localCodeDir
}
\",\"gpuNum\":0}`
const
trialConfig
:
string
=
`{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"
${
localCodeDir
}
\",\"gpuNum\":0}`
...
...
src/nni_manager/training_service/test/paiTrainingService.test.ts
View file @
d48ad027
...
@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => {
...
@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => {
chai
.
expect
(
trialDetail
.
status
).
to
.
be
.
equals
(
'
WAITING
'
);
chai
.
expect
(
trialDetail
.
status
).
to
.
be
.
equals
(
'
WAITING
'
);
}
catch
(
error
)
{
}
catch
(
error
)
{
console
.
log
(
'
Submit job failed:
'
+
error
);
console
.
log
(
'
Submit job failed:
'
+
error
);
chai
.
assert
(
error
)
chai
.
assert
(
error
)
}
}
});
});
});
});
\ No newline at end of file
src/nni_manager/tslint.json
View file @
d48ad027
...
@@ -9,7 +9,10 @@
...
@@ -9,7 +9,10 @@
"no-increment-decrement"
:
false
,
"no-increment-decrement"
:
false
,
"promise-function-async"
:
false
,
"promise-function-async"
:
false
,
"no-console"
:
[
true
,
"log"
],
"no-console"
:
[
true
,
"log"
],
"no-multiline-string"
:
false
"no-multiline-string"
:
false
,
"no-suspicious-comment"
:
false
,
"no-backbone-get-set-outside-model"
:
false
,
"max-classes-per-file"
:
false
},
},
"rulesDirectory"
:
[],
"rulesDirectory"
:
[],
"linterOptions"
:
{
"linterOptions"
:
{
...
...
src/nni_manager/types/child-process-promise/index.d.ts
View file @
d48ad027
...
@@ -7,5 +7,5 @@ declare module 'child-process-promise' {
...
@@ -7,5 +7,5 @@ declare module 'child-process-promise' {
stderr
:
string
,
stderr
:
string
,
message
:
string
message
:
string
}
}
}
}
}
}
\ No newline at end of file
src/nni_manager/types/webhdfs/index.d.ts
0 → 100644
View file @
d48ad027
declare
module
'
webhdfs
'
{
export
function
createClient
(
arg
:
any
):
any
;
}
\ No newline at end of file
src/sdk/pynni/nni/__main__.py
View file @
d48ad027
...
@@ -154,7 +154,7 @@ def main():
...
@@ -154,7 +154,7 @@ def main():
assessor
=
None
assessor
=
None
if
args
.
tuner_class_name
in
ModuleName
:
if
args
.
tuner_class_name
in
ModuleName
:
tuner
=
create_builtin_class_instance
(
tuner
=
create_builtin_class_instance
(
args
.
tuner_class_name
,
args
.
tuner_class_name
,
args
.
tuner_args
)
args
.
tuner_args
)
else
:
else
:
tuner
=
create_customized_class_instance
(
tuner
=
create_customized_class_instance
(
...
...
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
d48ad027
...
@@ -81,7 +81,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
...
@@ -81,7 +81,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
class
Bracket
():
class
Bracket
():
"""
"""
A bracket in BOHB, all the information of a bracket is managed by
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
an instance of this class.
Parameters
Parameters
...
@@ -251,7 +251,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -251,7 +251,7 @@ class BOHB(MsgDispatcherBase):
BOHB performs robust and efficient hyperparameter optimization
BOHB performs robust and efficient hyperparameter optimization
at scale by combining the speed of Hyperband searches with the
at scale by combining the speed of Hyperband searches with the
guidance and guarantees of convergence of Bayesian Optimization.
guidance and guarantees of convergence of Bayesian Optimization.
Instead of sampling new configurations at random, BOHB uses
Instead of sampling new configurations at random, BOHB uses
kernel density estimators to select promising candidates.
kernel density estimators to select promising candidates.
Parameters
Parameters
...
@@ -335,7 +335,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -335,7 +335,7 @@ class BOHB(MsgDispatcherBase):
pass
pass
def
handle_initialize
(
self
,
data
):
def
handle_initialize
(
self
,
data
):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
and search space formations
and search space formations
Parameters
Parameters
...
@@ -403,7 +403,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -403,7 +403,7 @@ class BOHB(MsgDispatcherBase):
If this function is called, Command will be sent by BOHB:
If this function is called, Command will be sent by BOHB:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
{
{
'parameter_id': id of new hyperparameter
'parameter_id': id of new hyperparameter
'parameter_source': 'algorithm'
'parameter_source': 'algorithm'
'parameters': value of new hyperparameter
'parameters': value of new hyperparameter
...
@@ -458,30 +458,30 @@ class BOHB(MsgDispatcherBase):
...
@@ -458,30 +458,30 @@ class BOHB(MsgDispatcherBase):
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
]))
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
]))
elif
_type
==
'quniform'
:
elif
_type
==
'quniform'
:
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
q
=
search_space
[
var
][
"_value"
][
2
]))
q
=
search_space
[
var
][
"_value"
][
2
]))
elif
_type
==
'loguniform'
:
elif
_type
==
'loguniform'
:
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
log
=
True
))
log
=
True
))
elif
_type
==
'qloguniform'
:
elif
_type
==
'qloguniform'
:
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
UniformFloatHyperparameter
(
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
var
,
lower
=
search_space
[
var
][
"_value"
][
0
],
upper
=
search_space
[
var
][
"_value"
][
1
],
q
=
search_space
[
var
][
"_value"
][
2
],
log
=
True
))
q
=
search_space
[
var
][
"_value"
][
2
],
log
=
True
))
elif
_type
==
'normal'
:
elif
_type
==
'normal'
:
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
]))
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
]))
elif
_type
==
'qnormal'
:
elif
_type
==
'qnormal'
:
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
q
=
search_space
[
var
][
"_value"
][
3
]))
q
=
search_space
[
var
][
"_value"
][
3
]))
elif
_type
==
'lognormal'
:
elif
_type
==
'lognormal'
:
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
log
=
True
))
log
=
True
))
elif
_type
==
'qlognormal'
:
elif
_type
==
'qlognormal'
:
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
cs
.
add_hyperparameter
(
CSH
.
NormalFloatHyperparameter
(
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
var
,
mu
=
search_space
[
var
][
"_value"
][
1
],
sigma
=
search_space
[
var
][
"_value"
][
2
],
q
=
search_space
[
var
][
"_value"
][
3
],
log
=
True
))
q
=
search_space
[
var
][
"_value"
][
3
],
log
=
True
))
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
@@ -553,7 +553,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -553,7 +553,7 @@ class BOHB(MsgDispatcherBase):
self
.
brackets
[
s
].
set_config_perf
(
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
self
.
completed_hyper_configs
.
append
(
data
)
self
.
completed_hyper_configs
.
append
(
data
)
_parameters
=
self
.
parameters
[
data
[
'parameter_id'
]]
_parameters
=
self
.
parameters
[
data
[
'parameter_id'
]]
_parameters
.
pop
(
_KEY
)
_parameters
.
pop
(
_KEY
)
# update BO with loss, max_s budget, hyperparameters
# update BO with loss, max_s budget, hyperparameters
...
...
src/sdk/pynni/nni/bohb_advisor/config_generator.py
View file @
d48ad027
...
@@ -117,7 +117,7 @@ class CG_BOHB(object):
...
@@ -117,7 +117,7 @@ class CG_BOHB(object):
seperated by budget. This function sample a configuration from
seperated by budget. This function sample a configuration from
largest budget. Firstly we sample "num_samples" configurations,
largest budget. Firstly we sample "num_samples" configurations,
then prefer one with the largest l(x)/g(x).
then prefer one with the largest l(x)/g(x).
Parameters:
Parameters:
-----------
-----------
info_dict: dict
info_dict: dict
...
...
src/sdk/pynni/nni/common.py
View file @
d48ad027
...
@@ -34,7 +34,7 @@ log_level_map = {
...
@@ -34,7 +34,7 @@ log_level_map = {
}
}
_time_format
=
'%m/%d/%Y, %I:%M:%S %p'
_time_format
=
'%m/%d/%Y, %I:%M:%S %p'
class
_LoggerFileWrapper
(
TextIOBase
):
class
_LoggerFileWrapper
(
TextIOBase
):
def
__init__
(
self
,
logger_file
):
def
__init__
(
self
,
logger_file
):
self
.
file
=
logger_file
self
.
file
=
logger_file
...
...
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
View file @
d48ad027
...
@@ -67,7 +67,7 @@ class CurvefittingAssessor(Assessor):
...
@@ -67,7 +67,7 @@ class CurvefittingAssessor(Assessor):
def
trial_end
(
self
,
trial_job_id
,
success
):
def
trial_end
(
self
,
trial_job_id
,
success
):
"""update the best performance of completed trial job
"""update the best performance of completed trial job
Parameters
Parameters
----------
----------
trial_job_id: int
trial_job_id: int
...
@@ -112,7 +112,7 @@ class CurvefittingAssessor(Assessor):
...
@@ -112,7 +112,7 @@ class CurvefittingAssessor(Assessor):
curr_step
=
len
(
trial_history
)
curr_step
=
len
(
trial_history
)
if
curr_step
<
self
.
start_step
:
if
curr_step
<
self
.
start_step
:
return
AssessResult
.
Good
return
AssessResult
.
Good
if
trial_job_id
in
self
.
last_judgment_num
.
keys
()
and
curr_step
-
self
.
last_judgment_num
[
trial_job_id
]
<
self
.
gap
:
if
trial_job_id
in
self
.
last_judgment_num
.
keys
()
and
curr_step
-
self
.
last_judgment_num
[
trial_job_id
]
<
self
.
gap
:
return
AssessResult
.
Good
return
AssessResult
.
Good
self
.
last_judgment_num
[
trial_job_id
]
=
curr_step
self
.
last_judgment_num
[
trial_job_id
]
=
curr_step
...
...
src/sdk/pynni/nni/curvefitting_assessor/curvefunctions.py
View file @
d48ad027
...
@@ -26,7 +26,7 @@ curve_combination_models = ['vap', 'pow3', 'linear', 'logx_linear', 'dr_hill_zer
...
@@ -26,7 +26,7 @@ curve_combination_models = ['vap', 'pow3', 'linear', 'logx_linear', 'dr_hill_zer
def
vap
(
x
,
a
,
b
,
c
):
def
vap
(
x
,
a
,
b
,
c
):
"""Vapor pressure model
"""Vapor pressure model
Parameters
Parameters
----------
----------
x: int
x: int
...
@@ -109,7 +109,7 @@ model_para_num['logx_linear'] = 2
...
@@ -109,7 +109,7 @@ model_para_num['logx_linear'] = 2
def
dr_hill_zero_background
(
x
,
theta
,
eta
,
kappa
):
def
dr_hill_zero_background
(
x
,
theta
,
eta
,
kappa
):
"""dr hill zero background
"""dr hill zero background
Parameters
Parameters
----------
----------
x: int
x: int
...
@@ -261,7 +261,7 @@ model_para_num['weibull'] = 4
...
@@ -261,7 +261,7 @@ model_para_num['weibull'] = 4
def
janoschek
(
x
,
a
,
beta
,
k
,
delta
):
def
janoschek
(
x
,
a
,
beta
,
k
,
delta
):
"""http://www.pisces-conservation.com/growthhelp/janoschek.htm
"""http://www.pisces-conservation.com/growthhelp/janoschek.htm
Parameters
Parameters
----------
----------
x: int
x: int
...
...
src/sdk/pynni/nni/curvefitting_assessor/model_factory.py
View file @
d48ad027
...
@@ -35,7 +35,7 @@ logger = logging.getLogger('curvefitting_Assessor')
...
@@ -35,7 +35,7 @@ logger = logging.getLogger('curvefitting_Assessor')
class
CurveModel
(
object
):
class
CurveModel
(
object
):
"""Build a Curve Model to predict the performance
"""Build a Curve Model to predict the performance
Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md
Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md
Parameters
Parameters
...
@@ -53,7 +53,7 @@ class CurveModel(object):
...
@@ -53,7 +53,7 @@ class CurveModel(object):
def
fit_theta
(
self
):
def
fit_theta
(
self
):
"""use least squares to fit all default curves parameter seperately
"""use least squares to fit all default curves parameter seperately
Returns
Returns
-------
-------
None
None
...
@@ -87,7 +87,7 @@ class CurveModel(object):
...
@@ -87,7 +87,7 @@ class CurveModel(object):
def
filter_curve
(
self
):
def
filter_curve
(
self
):
"""filter the poor performing curve
"""filter the poor performing curve
Returns
Returns
-------
-------
None
None
...
@@ -117,7 +117,7 @@ class CurveModel(object):
...
@@ -117,7 +117,7 @@ class CurveModel(object):
def
predict_y
(
self
,
model
,
pos
):
def
predict_y
(
self
,
model
,
pos
):
"""return the predict y of 'model' when epoch = pos
"""return the predict y of 'model' when epoch = pos
Parameters
Parameters
----------
----------
model: string
model: string
...
@@ -162,7 +162,7 @@ class CurveModel(object):
...
@@ -162,7 +162,7 @@ class CurveModel(object):
def
normalize_weight
(
self
,
samples
):
def
normalize_weight
(
self
,
samples
):
"""normalize weight
"""normalize weight
Parameters
Parameters
----------
----------
samples: list
samples: list
...
@@ -184,7 +184,7 @@ class CurveModel(object):
...
@@ -184,7 +184,7 @@ class CurveModel(object):
def
sigma_sq
(
self
,
sample
):
def
sigma_sq
(
self
,
sample
):
"""returns the value of sigma square, given the weight's sample
"""returns the value of sigma square, given the weight's sample
Parameters
Parameters
----------
----------
sample: list
sample: list
...
@@ -203,7 +203,7 @@ class CurveModel(object):
...
@@ -203,7 +203,7 @@ class CurveModel(object):
def
normal_distribution
(
self
,
pos
,
sample
):
def
normal_distribution
(
self
,
pos
,
sample
):
"""returns the value of normal distribution, given the weight's sample and target position
"""returns the value of normal distribution, given the weight's sample and target position
Parameters
Parameters
----------
----------
pos: int
pos: int
...
@@ -227,7 +227,7 @@ class CurveModel(object):
...
@@ -227,7 +227,7 @@ class CurveModel(object):
----------
----------
sample: list
sample: list
sample is a (1 * NUM_OF_FUNCTIONS) matrix, representing{w1, w2, ... wk}
sample is a (1 * NUM_OF_FUNCTIONS) matrix, representing{w1, w2, ... wk}
Returns
Returns
-------
-------
float
float
...
@@ -241,13 +241,13 @@ class CurveModel(object):
...
@@ -241,13 +241,13 @@ class CurveModel(object):
def
prior
(
self
,
samples
):
def
prior
(
self
,
samples
):
"""priori distribution
"""priori distribution
Parameters
Parameters
----------
----------
samples: list
samples: list
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
Returns
Returns
-------
-------
float
float
...
@@ -264,13 +264,13 @@ class CurveModel(object):
...
@@ -264,13 +264,13 @@ class CurveModel(object):
def
target_distribution
(
self
,
samples
):
def
target_distribution
(
self
,
samples
):
"""posterior probability
"""posterior probability
Parameters
Parameters
----------
----------
samples: list
samples: list
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
Returns
Returns
-------
-------
float
float
...
@@ -319,7 +319,7 @@ class CurveModel(object):
...
@@ -319,7 +319,7 @@ class CurveModel(object):
def
predict
(
self
,
trial_history
):
def
predict
(
self
,
trial_history
):
"""predict the value of target position
"""predict the value of target position
Parameters
Parameters
----------
----------
trial_history: list
trial_history: list
...
...
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
View file @
d48ad027
...
@@ -167,7 +167,7 @@ class EvolutionTuner(Tuner):
...
@@ -167,7 +167,7 @@ class EvolutionTuner(Tuner):
self
.
space
=
None
self
.
space
=
None
def
update_search_space
(
self
,
search_space
):
def
update_search_space
(
self
,
search_space
):
"""Update search space.
"""Update search space.
Search_space contains the information that user pre-defined.
Search_space contains the information that user pre-defined.
Parameters
Parameters
...
@@ -194,7 +194,7 @@ class EvolutionTuner(Tuner):
...
@@ -194,7 +194,7 @@ class EvolutionTuner(Tuner):
Parameters
Parameters
----------
----------
parameter_id : int
parameter_id : int
Returns
Returns
-------
-------
config : dict
config : dict
...
...
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
d48ad027
...
@@ -43,7 +43,7 @@ _epsilon = 1e-6
...
@@ -43,7 +43,7 @@ _epsilon = 1e-6
def
create_parameter_id
():
def
create_parameter_id
():
"""Create an id
"""Create an id
Returns
Returns
-------
-------
int
int
...
@@ -55,7 +55,7 @@ def create_parameter_id():
...
@@ -55,7 +55,7 @@ def create_parameter_id():
def
create_bracket_parameter_id
(
brackets_id
,
brackets_curr_decay
,
increased_id
=-
1
):
def
create_bracket_parameter_id
(
brackets_id
,
brackets_curr_decay
,
increased_id
=-
1
):
"""Create a full id for a specific bracket's hyperparameter configuration
"""Create a full id for a specific bracket's hyperparameter configuration
Parameters
Parameters
----------
----------
brackets_id: int
brackets_id: int
...
@@ -79,7 +79,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
...
@@ -79,7 +79,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
def
json2parameter
(
ss_spec
,
random_state
):
def
json2parameter
(
ss_spec
,
random_state
):
"""Randomly generate values for hyperparameters from hyperparameter space i.e., x.
"""Randomly generate values for hyperparameters from hyperparameter space i.e., x.
Parameters
Parameters
----------
----------
ss_spec:
ss_spec:
...
@@ -116,7 +116,7 @@ def json2parameter(ss_spec, random_state):
...
@@ -116,7 +116,7 @@ def json2parameter(ss_spec, random_state):
class
Bracket
():
class
Bracket
():
"""A bracket in Hyperband, all the information of a bracket is managed by an instance of this class
"""A bracket in Hyperband, all the information of a bracket is managed by an instance of this class
Parameters
Parameters
----------
----------
s: int
s: int
...
@@ -132,7 +132,7 @@ class Bracket():
...
@@ -132,7 +132,7 @@ class Bracket():
optimize_mode: str
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
optimize mode, 'maximize' or 'minimize'
"""
"""
def
__init__
(
self
,
s
,
s_max
,
eta
,
R
,
optimize_mode
):
def
__init__
(
self
,
s
,
s_max
,
eta
,
R
,
optimize_mode
):
self
.
bracket_id
=
s
self
.
bracket_id
=
s
self
.
s_max
=
s_max
self
.
s_max
=
s_max
...
@@ -163,7 +163,7 @@ class Bracket():
...
@@ -163,7 +163,7 @@ class Bracket():
def
set_config_perf
(
self
,
i
,
parameter_id
,
seq
,
value
):
def
set_config_perf
(
self
,
i
,
parameter_id
,
seq
,
value
):
"""update trial's latest result with its sequence number, e.g., epoch number or batch number
"""update trial's latest result with its sequence number, e.g., epoch number or batch number
Parameters
Parameters
----------
----------
i: int
i: int
...
@@ -184,7 +184,7 @@ class Bracket():
...
@@ -184,7 +184,7 @@ class Bracket():
self
.
configs_perf
[
i
][
parameter_id
]
=
[
seq
,
value
]
self
.
configs_perf
[
i
][
parameter_id
]
=
[
seq
,
value
]
else
:
else
:
self
.
configs_perf
[
i
][
parameter_id
]
=
[
seq
,
value
]
self
.
configs_perf
[
i
][
parameter_id
]
=
[
seq
,
value
]
def
inform_trial_end
(
self
,
i
):
def
inform_trial_end
(
self
,
i
):
"""If the trial is finished and the corresponding round (i.e., i) has all its trials finished,
"""If the trial is finished and the corresponding round (i.e., i) has all its trials finished,
...
@@ -230,7 +230,7 @@ class Bracket():
...
@@ -230,7 +230,7 @@ class Bracket():
----------
----------
num: int
num: int
the number of hyperparameter configurations
the number of hyperparameter configurations
Returns
Returns
-------
-------
list
list
...
@@ -350,7 +350,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -350,7 +350,7 @@ class Hyperband(MsgDispatcherBase):
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
"""data: JSON object, which is search space
"""data: JSON object, which is search space
Parameters
Parameters
----------
----------
data: int
data: int
...
@@ -392,9 +392,9 @@ class Hyperband(MsgDispatcherBase):
...
@@ -392,9 +392,9 @@ class Hyperband(MsgDispatcherBase):
"""
"""
Parameters
Parameters
----------
----------
data:
data:
it is an object which has keys 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
it is an object which has keys 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
Raises
Raises
------
------
ValueError
ValueError
...
...
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
View file @
d48ad027
...
@@ -21,10 +21,10 @@ from nni.assessor import Assessor, AssessResult
...
@@ -21,10 +21,10 @@ from nni.assessor import Assessor, AssessResult
logger
=
logging
.
getLogger
(
'medianstop_Assessor'
)
logger
=
logging
.
getLogger
(
'medianstop_Assessor'
)
class
MedianstopAssessor
(
Assessor
):
class
MedianstopAssessor
(
Assessor
):
"""MedianstopAssessor is The median stopping rule stops a pending trial X at step S
"""MedianstopAssessor is The median stopping rule stops a pending trial X at step S
if the trial’s best objective value by step S is strictly worse than the median value
if the trial’s best objective value by step S is strictly worse than the median value
of the running averages of all completed trials’ objectives reported up to step S
of the running averages of all completed trials’ objectives reported up to step S
Parameters
Parameters
----------
----------
optimize_mode: str
optimize_mode: str
...
@@ -60,7 +60,7 @@ class MedianstopAssessor(Assessor):
...
@@ -60,7 +60,7 @@ class MedianstopAssessor(Assessor):
def
trial_end
(
self
,
trial_job_id
,
success
):
def
trial_end
(
self
,
trial_job_id
,
success
):
"""trial_end
"""trial_end
Parameters
Parameters
----------
----------
trial_job_id: int
trial_job_id: int
...
@@ -83,7 +83,7 @@ class MedianstopAssessor(Assessor):
...
@@ -83,7 +83,7 @@ class MedianstopAssessor(Assessor):
def
assess_trial
(
self
,
trial_job_id
,
trial_history
):
def
assess_trial
(
self
,
trial_job_id
,
trial_history
):
"""assess_trial
"""assess_trial
Parameters
Parameters
----------
----------
trial_job_id: int
trial_job_id: int
...
...
src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py
View file @
d48ad027
...
@@ -27,7 +27,7 @@ from scipy.optimize import minimize
...
@@ -27,7 +27,7 @@ from scipy.optimize import minimize
import
nni.metis_tuner.lib_data
as
lib_data
import
nni.metis_tuner.lib_data
as
lib_data
def
next_hyperparameter_expected_improvement
(
fun_prediction
,
def
next_hyperparameter_expected_improvement
(
fun_prediction
,
fun_prediction_args
,
fun_prediction_args
,
x_bounds
,
x_types
,
x_bounds
,
x_types
,
samples_y_aggregation
,
samples_y_aggregation
,
...
...
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
View file @
d48ad027
...
@@ -69,7 +69,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -69,7 +69,7 @@ class NetworkMorphismTuner(Tuner):
optimize_mode : str
optimize_mode : str
optimize mode "minimize" or "maximize" (default: {"minimize"})
optimize mode "minimize" or "maximize" (default: {"minimize"})
path : str
path : str
default mode path to save the model file (default: {"model_path"})
default mode path to save the model file (default: {"model_path"})
verbose : bool
verbose : bool
verbose to print the log (default: {True})
verbose to print the log (default: {True})
beta : float
beta : float
...
@@ -154,7 +154,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -154,7 +154,7 @@ class NetworkMorphismTuner(Tuner):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
""" Record an observation of the objective function.
""" Record an observation of the objective function.
Parameters
Parameters
----------
----------
parameter_id : int
parameter_id : int
...
@@ -267,7 +267,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -267,7 +267,7 @@ class NetworkMorphismTuner(Tuner):
----------
----------
model_id : int
model_id : int
model index
model index
Returns
Returns
-------
-------
load_model : Graph
load_model : Graph
...
@@ -297,7 +297,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -297,7 +297,7 @@ class NetworkMorphismTuner(Tuner):
----------
----------
model_id : int
model_id : int
model index
model index
Returns
Returns
-------
-------
float
float
...
...
Prev
1
…
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment