Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6febf29
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "3bee8403fbceab909782505998cff6c35fad7a1a"
Commit
d6febf29
authored
Jun 25, 2019
by
suiguoxin
Browse files
Merge branch 'master' of
git://github.com/microsoft/nni
parents
77c95479
c2179921
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
152 additions
and
40 deletions
+152
-40
examples/tuners/random_nas_tuner/random_nas_tuner.py
examples/tuners/random_nas_tuner/random_nas_tuner.py
+2
-2
examples/tuners/weight_sharing/ga_customer_tuner/customer_tuner.py
...tuners/weight_sharing/ga_customer_tuner/customer_tuner.py
+2
-2
src/nni_manager/common/utils.ts
src/nni_manager/common/utils.ts
+1
-1
src/nni_manager/rest_server/restValidationSchemas.ts
src/nni_manager/rest_server/restValidationSchemas.ts
+1
-0
src/nni_manager/training_service/common/clusterJobRestServer.ts
...i_manager/training_service/common/clusterJobRestServer.ts
+5
-1
src/nni_manager/training_service/local/localTrainingService.ts
...ni_manager/training_service/local/localTrainingService.ts
+5
-3
src/nni_manager/training_service/pai/paiData.ts
src/nni_manager/training_service/pai/paiData.ts
+4
-4
src/nni_manager/training_service/pai/paiJobRestServer.ts
src/nni_manager/training_service/pai/paiJobRestServer.ts
+38
-0
src/nni_manager/training_service/pai/paiTrainingService.ts
src/nni_manager/training_service/pai/paiTrainingService.ts
+64
-6
src/nni_manager/types/tail-stream/index.d.ts
src/nni_manager/types/tail-stream/index.d.ts
+2
-1
src/sdk/pynni/nni/__init__.py
src/sdk/pynni/nni/__init__.py
+1
-0
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+4
-6
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
+2
-2
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+1
-1
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+8
-0
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
+2
-2
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
+2
-2
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+1
-1
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+2
-2
src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py
src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py
+5
-4
No files found.
examples/tuners/random_nas_tuner/random_nas_tuner.py
View file @
d6febf29
...
@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner):
...
@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner):
self
.
searchspace_json
=
search_space
self
.
searchspace_json
=
search_space
self
.
random_state
=
np
.
random
.
RandomState
()
self
.
random_state
=
np
.
random
.
RandomState
()
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
'''generate
'''generate
'''
'''
return
random_archi_generator
(
self
.
searchspace_json
,
self
.
random_state
)
return
random_archi_generator
(
self
.
searchspace_json
,
self
.
random_state
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
'''receive
'''receive
'''
'''
pass
pass
examples/tuners/weight_sharing/ga_customer_tuner/customer_tuner.py
View file @
d6febf29
...
@@ -112,7 +112,7 @@ class CustomerTuner(Tuner):
...
@@ -112,7 +112,7 @@ class CustomerTuner(Tuner):
population
.
append
(
Individual
(
indiv_id
=
self
.
generate_new_id
(),
graph_cfg
=
graph_tmp
,
result
=
None
))
population
.
append
(
Individual
(
indiv_id
=
self
.
generate_new_id
(),
graph_cfg
=
graph_tmp
,
result
=
None
))
return
population
return
population
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a set of trial graph config, as a serializable object.
"""Returns a set of trial graph config, as a serializable object.
An example configuration:
An example configuration:
```json
```json
...
@@ -196,7 +196,7 @@ class CustomerTuner(Tuner):
...
@@ -196,7 +196,7 @@ class CustomerTuner(Tuner):
logger
.
debug
(
"trial {} ready"
.
format
(
indiv
.
indiv_id
))
logger
.
debug
(
"trial {} ready"
.
format
(
indiv
.
indiv_id
))
return
param_json
return
param_json
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
'''
'''
Record an observation of the objective function
Record an observation of the objective function
parameter_id : int
parameter_id : int
...
...
src/nni_manager/common/utils.ts
View file @
d6febf29
...
@@ -375,7 +375,7 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
...
@@ -375,7 +375,7 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
}
}
function
validateFileName
(
fileName
:
string
):
boolean
{
function
validateFileName
(
fileName
:
string
):
boolean
{
let
pattern
:
string
=
'
^[a-z0-9A-Z
\
.
-
_]+$
'
;
let
pattern
:
string
=
'
^[a-z0-9A-Z
\
._
-
]+$
'
;
const
validateResult
=
fileName
.
match
(
pattern
);
const
validateResult
=
fileName
.
match
(
pattern
);
if
(
validateResult
)
{
if
(
validateResult
)
{
return
true
;
return
true
;
...
...
src/nni_manager/rest_server/restValidationSchemas.ts
View file @
d6febf29
...
@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
...
@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
command
:
joi
.
string
().
min
(
1
),
command
:
joi
.
string
().
min
(
1
),
virtualCluster
:
joi
.
string
(),
virtualCluster
:
joi
.
string
(),
shmMB
:
joi
.
number
(),
shmMB
:
joi
.
number
(),
nasMode
:
joi
.
string
().
valid
(
'
classic_mode
'
,
'
enas_mode
'
,
'
oneshot_mode
'
),
worker
:
joi
.
object
({
worker
:
joi
.
object
({
replicas
:
joi
.
number
().
min
(
1
).
required
(),
replicas
:
joi
.
number
().
min
(
1
).
required
(),
image
:
joi
.
string
().
min
(
1
),
image
:
joi
.
string
().
min
(
1
),
...
...
src/nni_manager/training_service/common/clusterJobRestServer.ts
View file @
d6febf29
...
@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer {
...
@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer {
this
.
port
=
basePort
+
1
;
this
.
port
=
basePort
+
1
;
}
}
get
apiRootUrl
():
string
{
return
this
.
API_ROOT_URL
;
}
public
get
clusterRestServerPort
():
number
{
public
get
clusterRestServerPort
():
number
{
if
(
this
.
port
===
undefined
)
{
if
(
this
.
port
===
undefined
)
{
throw
new
Error
(
'
PAI Rest server port is undefined
'
);
throw
new
Error
(
'
PAI Rest server port is undefined
'
);
...
@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer {
...
@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer {
protected
abstract
handleTrialMetrics
(
jobId
:
string
,
trialMetrics
:
any
[])
:
void
;
protected
abstract
handleTrialMetrics
(
jobId
:
string
,
trialMetrics
:
any
[])
:
void
;
// tslint:disable: no-unsafe-any no-any
// tslint:disable: no-unsafe-any no-any
pr
ivate
createRestHandler
()
:
Router
{
pr
otected
createRestHandler
()
:
Router
{
const
router
:
Router
=
Router
();
const
router
:
Router
=
Router
();
router
.
use
((
req
:
Request
,
res
:
Response
,
next
:
any
)
=>
{
router
.
use
((
req
:
Request
,
res
:
Response
,
next
:
any
)
=>
{
...
...
src/nni_manager/training_service/local/localTrainingService.ts
View file @
d6febf29
...
@@ -355,7 +355,8 @@ class LocalTrainingService implements TrainingService {
...
@@ -355,7 +355,8 @@ class LocalTrainingService implements TrainingService {
this
.
log
.
info
(
'
Stopping local machine training service...
'
);
this
.
log
.
info
(
'
Stopping local machine training service...
'
);
this
.
stopping
=
true
;
this
.
stopping
=
true
;
for
(
const
stream
of
this
.
jobStreamMap
.
values
())
{
for
(
const
stream
of
this
.
jobStreamMap
.
values
())
{
stream
.
destroy
();
stream
.
end
(
0
)
stream
.
emit
(
'
end
'
)
}
}
if
(
this
.
gpuScheduler
!==
undefined
)
{
if
(
this
.
gpuScheduler
!==
undefined
)
{
await
this
.
gpuScheduler
.
stop
();
await
this
.
gpuScheduler
.
stop
();
...
@@ -372,7 +373,9 @@ class LocalTrainingService implements TrainingService {
...
@@ -372,7 +373,9 @@ class LocalTrainingService implements TrainingService {
if
(
stream
===
undefined
)
{
if
(
stream
===
undefined
)
{
throw
new
Error
(
`Could not find stream in trial
${
trialJob
.
id
}
`
);
throw
new
Error
(
`Could not find stream in trial
${
trialJob
.
id
}
`
);
}
}
stream
.
destroy
();
//Refer https://github.com/Juul/tail-stream/issues/20
stream
.
end
(
0
)
stream
.
emit
(
'
end
'
)
this
.
jobStreamMap
.
delete
(
trialJob
.
id
);
this
.
jobStreamMap
.
delete
(
trialJob
.
id
);
}
}
}
}
...
@@ -567,7 +570,6 @@ class LocalTrainingService implements TrainingService {
...
@@ -567,7 +570,6 @@ class LocalTrainingService implements TrainingService {
buffer
=
remain
;
buffer
=
remain
;
}
}
});
});
this
.
jobStreamMap
.
set
(
trialJobDetail
.
id
,
stream
);
this
.
jobStreamMap
.
set
(
trialJobDetail
.
id
,
stream
);
}
}
...
...
src/nni_manager/training_service/pai/paiData.ts
View file @
d6febf29
...
@@ -64,11 +64,11 @@ else
...
@@ -64,11 +64,11 @@ else
fi`
;
fi`
;
export
const
PAI_TRIAL_COMMAND_FORMAT
:
string
=
export
const
PAI_TRIAL_COMMAND_FORMAT
:
string
=
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
MULTI_PHASE={5}
\
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{
5
}' --nnimanager_ip '{
6
}' --nnimanager_port '{
7
}' \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{
6
}' --nnimanager_ip '{
7
}' --nnimanager_port '{
8
}' \
--pai_hdfs_output_dir '{
8
}' --pai_hdfs_host '{
9
}' --pai_user_name {1
0
} --nni_hdfs_exp_dir '{1
1
}' --webhdfs_path '/webhdfs/api/v1' \
--pai_hdfs_output_dir '{
9
}' --pai_hdfs_host '{
10
}' --pai_user_name {1
1
} --nni_hdfs_exp_dir '{1
2
}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{1
2
}' --log_collection '{1
3
}'`
;
--nni_manager_version '{1
3
}' --log_collection '{1
4
}'`
;
export
const
PAI_OUTPUT_DIR_FORMAT
:
string
=
export
const
PAI_OUTPUT_DIR_FORMAT
:
string
=
`hdfs://{0}:9000/`
;
`hdfs://{0}:9000/`
;
...
...
src/nni_manager/training_service/pai/paiJobRestServer.ts
View file @
d6febf29
...
@@ -19,17 +19,26 @@
...
@@ -19,17 +19,26 @@
'
use strict
'
;
'
use strict
'
;
import
{
Request
,
Response
,
Router
}
from
'
express
'
;
import
{
Inject
}
from
'
typescript-ioc
'
;
import
{
Inject
}
from
'
typescript-ioc
'
;
import
*
as
component
from
'
../../common/component
'
;
import
*
as
component
from
'
../../common/component
'
;
import
{
ClusterJobRestServer
}
from
'
../common/clusterJobRestServer
'
;
import
{
ClusterJobRestServer
}
from
'
../common/clusterJobRestServer
'
;
import
{
PAITrainingService
}
from
'
./paiTrainingService
'
;
import
{
PAITrainingService
}
from
'
./paiTrainingService
'
;
export
interface
ParameterFileMeta
{
readonly
experimentId
:
string
;
readonly
trialId
:
string
;
readonly
filePath
:
string
;
}
/**
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*
*/
*/
@
component
.
Singleton
@
component
.
Singleton
export
class
PAIJobRestServer
extends
ClusterJobRestServer
{
export
class
PAIJobRestServer
extends
ClusterJobRestServer
{
private
parameterFileMetaList
:
ParameterFileMeta
[]
=
[];
@
Inject
@
Inject
private
readonly
paiTrainingService
:
PAITrainingService
;
private
readonly
paiTrainingService
:
PAITrainingService
;
...
@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
...
@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
});
});
}
}
}
}
protected
createRestHandler
():
Router
{
const
router
:
Router
=
super
.
createRestHandler
();
router
.
post
(
`/parameter-file-meta`
,
(
req
:
Request
,
res
:
Response
)
=>
{
try
{
this
.
log
.
info
(
`POST /parameter-file-meta, body is
${
JSON
.
stringify
(
req
.
body
)}
`
);
this
.
parameterFileMetaList
.
push
(
req
.
body
);
res
.
send
();
}
catch
(
err
)
{
this
.
log
.
error
(
`POST parameter-file-meta error:
${
err
}
`
);
res
.
status
(
500
);
res
.
send
(
err
.
message
);
}
});
router
.
get
(
`/parameter-file-meta`
,
(
req
:
Request
,
res
:
Response
)
=>
{
try
{
this
.
log
.
info
(
`GET /parameter-file-meta`
);
res
.
send
(
this
.
parameterFileMetaList
);
}
catch
(
err
)
{
this
.
log
.
error
(
`GET parameter-file-meta error:
${
err
}
`
);
res
.
status
(
500
);
res
.
send
(
err
.
message
);
}
});
return
router
;
}
}
}
src/nni_manager/training_service/pai/paiTrainingService.ts
View file @
d6febf29
...
@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
...
@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import
{
getExperimentId
,
getInitTrialSequenceId
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getExperimentId
,
getInitTrialSequenceId
}
from
'
../../common/experimentStartupInfo
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
getLogger
,
Logger
}
from
'
../../common/log
'
;
import
{
import
{
JobApplicationForm
,
NNIManagerIpConfig
,
TrainingService
,
HyperParameters
,
JobApplicationForm
,
NNIManagerIpConfig
,
TrainingService
,
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
TrialJobApplicationForm
,
TrialJobDetail
,
TrialJobMetric
}
from
'
../../common/trainingService
'
;
}
from
'
../../common/trainingService
'
;
import
{
delay
,
generateParamFileName
,
import
{
delay
,
generateParamFileName
,
...
@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
...
@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import
{
NNIPAITrialConfig
,
PAIClusterConfig
,
PAIJobConfig
,
PAITaskRole
}
from
'
./paiConfig
'
;
import
{
NNIPAITrialConfig
,
PAIClusterConfig
,
PAIJobConfig
,
PAITaskRole
}
from
'
./paiConfig
'
;
import
{
PAI_LOG_PATH_FORMAT
,
PAI_OUTPUT_DIR_FORMAT
,
PAI_TRIAL_COMMAND_FORMAT
,
PAITrialJobDetail
}
from
'
./paiData
'
;
import
{
PAI_LOG_PATH_FORMAT
,
PAI_OUTPUT_DIR_FORMAT
,
PAI_TRIAL_COMMAND_FORMAT
,
PAITrialJobDetail
}
from
'
./paiData
'
;
import
{
PAIJobInfoCollector
}
from
'
./paiJobInfoCollector
'
;
import
{
PAIJobInfoCollector
}
from
'
./paiJobInfoCollector
'
;
import
{
PAIJobRestServer
}
from
'
./paiJobRestServer
'
;
import
{
PAIJobRestServer
,
ParameterFileMeta
}
from
'
./paiJobRestServer
'
;
import
*
as
WebHDFS
from
'
webhdfs
'
;
import
*
as
WebHDFS
from
'
webhdfs
'
;
...
@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
...
@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private
copyExpCodeDirPromise
?:
Promise
<
void
>
;
private
copyExpCodeDirPromise
?:
Promise
<
void
>
;
private
versionCheck
:
boolean
=
true
;
private
versionCheck
:
boolean
=
true
;
private
logCollection
:
string
;
private
logCollection
:
string
;
private
isMultiPhase
:
boolean
=
false
;
constructor
()
{
constructor
()
{
this
.
log
=
getLogger
();
this
.
log
=
getLogger
();
...
@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
...
@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return
deferred
.
promise
;
return
deferred
.
promise
;
}
}
public
updateTrialJob
(
trialJobId
:
string
,
form
:
JobApplicationForm
):
Promise
<
TrialJobDetail
>
{
public
async
updateTrialJob
(
trialJobId
:
string
,
form
:
JobApplicationForm
):
Promise
<
TrialJobDetail
>
{
throw
new
MethodNotImplementedError
();
const
trialJobDetail
:
undefined
|
TrialJobDetail
=
this
.
trialJobsMap
.
get
(
trialJobId
);
if
(
trialJobDetail
===
undefined
)
{
throw
new
Error
(
`updateTrialJob failed:
${
trialJobId
}
not found`
);
}
if
(
form
.
jobType
===
'
TRIAL
'
)
{
await
this
.
writeParameterFile
(
trialJobId
,
(
<
TrialJobApplicationForm
>
form
).
hyperParameters
);
}
else
{
throw
new
Error
(
`updateTrialJob failed: jobType
${
form
.
jobType
}
not supported.`
);
}
return
trialJobDetail
;
}
}
public
get
isMultiPhaseJobSupported
():
boolean
{
public
get
isMultiPhaseJobSupported
():
boolean
{
return
fals
e
;
return
tru
e
;
}
}
// tslint:disable:no-http-string
// tslint:disable:no-http-string
...
@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
...
@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case
TrialConfigMetadataKey
.
LOG_COLLECTION
:
case
TrialConfigMetadataKey
.
LOG_COLLECTION
:
this
.
logCollection
=
value
;
this
.
logCollection
=
value
;
break
;
break
;
case
TrialConfigMetadataKey
.
MULTI_PHASE
:
this
.
isMultiPhase
=
(
value
===
'
true
'
||
value
===
'
True
'
);
break
;
default
:
default
:
//Reject for unknown keys
//Reject for unknown keys
throw
new
Error
(
`Uknown key:
${
key
}
`
);
throw
new
Error
(
`Uknown key:
${
key
}
`
);
...
@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
...
@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId
,
trialJobId
,
this
.
experimentId
,
this
.
experimentId
,
trialJobDetail
.
sequenceId
,
trialJobDetail
.
sequenceId
,
this
.
isMultiPhase
,
this
.
paiTrialConfig
.
command
,
this
.
paiTrialConfig
.
command
,
nniManagerIp
,
nniManagerIp
,
this
.
paiRestServerPort
,
this
.
paiRestServerPort
,
...
@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
...
@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return
Promise
.
race
([
timeoutDelay
,
deferred
.
promise
])
return
Promise
.
race
([
timeoutDelay
,
deferred
.
promise
])
.
finally
(()
=>
{
clearTimeout
(
timeoutId
);
});
.
finally
(()
=>
{
clearTimeout
(
timeoutId
);
});
}
}
// tslint:enable:no-any no-unsafe-any no-http-string
private
async
writeParameterFile
(
trialJobId
:
string
,
hyperParameters
:
HyperParameters
):
Promise
<
void
>
{
if
(
this
.
paiClusterConfig
===
undefined
)
{
throw
new
Error
(
'
PAI Cluster config is not initialized
'
);
}
if
(
this
.
paiTrialConfig
===
undefined
)
{
throw
new
Error
(
'
PAI trial config is not initialized
'
);
}
const
trialLocalTempFolder
:
string
=
path
.
join
(
getExperimentRootDir
(),
'
trials-local
'
,
trialJobId
);
const
hpFileName
:
string
=
generateParamFileName
(
hyperParameters
);
const
localFilepath
:
string
=
path
.
join
(
trialLocalTempFolder
,
hpFileName
);
await
fs
.
promises
.
writeFile
(
localFilepath
,
hyperParameters
.
value
,
{
encoding
:
'
utf8
'
});
const
hdfsCodeDir
:
string
=
HDFSClientUtility
.
getHdfsTrialWorkDir
(
this
.
paiClusterConfig
.
userName
,
trialJobId
);
const
hdfsHpFilePath
:
string
=
path
.
join
(
hdfsCodeDir
,
hpFileName
);
await
HDFSClientUtility
.
copyFileToHdfs
(
localFilepath
,
hdfsHpFilePath
,
this
.
hdfsClient
);
await
this
.
postParameterFileMeta
({
experimentId
:
this
.
experimentId
,
trialId
:
trialJobId
,
filePath
:
hdfsHpFilePath
});
}
private
postParameterFileMeta
(
parameterFileMeta
:
ParameterFileMeta
):
Promise
<
void
>
{
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
restServer
:
PAIJobRestServer
=
component
.
get
(
PAIJobRestServer
);
const
req
:
request
.
Options
=
{
uri
:
`
${
restServer
.
endPoint
}${
restServer
.
apiRootUrl
}
/parameter-file-meta`
,
method
:
'
POST
'
,
json
:
true
,
body
:
parameterFileMeta
};
request
(
req
,
(
err
:
Error
,
res
:
request
.
Response
)
=>
{
if
(
err
)
{
deferred
.
reject
(
err
);
}
else
{
deferred
.
resolve
();
}
});
return
deferred
.
promise
;
}
}
}
export
{
PAITrainingService
};
export
{
PAITrainingService
};
src/nni_manager/types/tail-stream/index.d.ts
View file @
d6febf29
declare
module
'
tail-stream
'
{
declare
module
'
tail-stream
'
{
export
interface
Stream
{
export
interface
Stream
{
on
(
type
:
'
data
'
,
callback
:
(
data
:
Buffer
)
=>
void
):
void
;
on
(
type
:
'
data
'
,
callback
:
(
data
:
Buffer
)
=>
void
):
void
;
destroy
():
void
;
end
(
data
:
number
):
void
;
emit
(
data
:
string
):
void
;
}
}
export
function
createReadStream
(
path
:
string
):
Stream
;
export
function
createReadStream
(
path
:
string
):
Stream
;
}
}
\ No newline at end of file
src/sdk/pynni/nni/__init__.py
View file @
d6febf29
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
from
.trial
import
*
from
.trial
import
*
from
.smartparam
import
*
from
.smartparam
import
*
from
.nas_utils
import
reload_tensorflow_variables
class
NoMoreTrialError
(
Exception
):
class
NoMoreTrialError
(
Exception
):
def
__init__
(
self
,
ErrorInfo
):
def
__init__
(
self
,
ErrorInfo
):
...
...
src/sdk/pynni/nni/__main__.py
View file @
d6febf29
...
@@ -28,9 +28,8 @@ import json
...
@@ -28,9 +28,8 @@ import json
import
importlib
import
importlib
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
,
AdvisorModuleName
,
AdvisorClassName
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
,
AdvisorModuleName
,
AdvisorClassName
from
nni.common
import
enable_multi_thread
from
nni.common
import
enable_multi_thread
,
enable_multi_phase
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
.
debug
(
'START'
)
logger
.
debug
(
'START'
)
...
@@ -126,6 +125,8 @@ def main():
...
@@ -126,6 +125,8 @@ def main():
args
=
parse_args
()
args
=
parse_args
()
if
args
.
multi_thread
:
if
args
.
multi_thread
:
enable_multi_thread
()
enable_multi_thread
()
if
args
.
multi_phase
:
enable_multi_phase
()
if
args
.
advisor_class_name
:
if
args
.
advisor_class_name
:
# advisor is enabled and starts to run
# advisor is enabled and starts to run
...
@@ -180,10 +181,7 @@ def main():
...
@@ -180,10 +181,7 @@ def main():
if
assessor
is
None
:
if
assessor
is
None
:
raise
AssertionError
(
'Failed to create Assessor instance'
)
raise
AssertionError
(
'Failed to create Assessor instance'
)
if
args
.
multi_phase
:
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
,
assessor
)
else
:
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
try
:
try
:
dispatcher
.
run
()
dispatcher
.
run
()
...
...
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
View file @
d6febf29
...
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
...
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
"""
self
.
values
=
self
.
is_valid
(
search_space
)
self
.
values
=
self
.
is_valid
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
Parameters
...
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
...
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
return
self
.
values
[
self
.
count
]
return
self
.
values
[
self
.
count
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
pass
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
d6febf29
...
@@ -106,7 +106,7 @@ class Bracket():
...
@@ -106,7 +106,7 @@ class Bracket():
self
.
s_max
=
s_max
self
.
s_max
=
s_max
self
.
eta
=
eta
self
.
eta
=
eta
self
.
max_budget
=
max_budget
self
.
max_budget
=
max_budget
self
.
optimize_mode
=
optimize_mode
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
n
=
math
.
ceil
((
s_max
+
1
)
*
eta
**
s
/
(
s
+
1
)
-
_epsilon
)
self
.
n
=
math
.
ceil
((
s_max
+
1
)
*
eta
**
s
/
(
s
+
1
)
-
_epsilon
)
self
.
r
=
max_budget
/
eta
**
s
self
.
r
=
max_budget
/
eta
**
s
...
...
src/sdk/pynni/nni/common.py
View file @
d6febf29
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
_multi_thread
=
False
_multi_thread
=
False
_multi_phase
=
False
def
enable_multi_thread
():
def
enable_multi_thread
():
global
_multi_thread
global
_multi_thread
...
@@ -76,3 +77,10 @@ def enable_multi_thread():
...
@@ -76,3 +77,10 @@ def enable_multi_thread():
def
multi_thread_enabled
():
def
multi_thread_enabled
():
return
_multi_thread
return
_multi_thread
def
enable_multi_phase
():
global
_multi_phase
_multi_phase
=
True
def
multi_phase_enabled
():
return
_multi_phase
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
View file @
d6febf29
...
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
...
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
population
.
append
(
Individual
(
config
=
config
))
self
.
population
.
append
(
Individual
(
config
=
config
))
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
Parameters
...
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
...
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config
=
split_index
(
total_config
)
config
=
split_index
(
total_config
)
return
config
return
config
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
'''Record the result from a trial
'''Record the result from a trial
Parameters
Parameters
...
...
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
View file @
d6febf29
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
'''
self
.
expanded_search_space
=
self
.
json2parameter
(
search_space
)
self
.
expanded_search_space
=
self
.
json2parameter
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
self
.
count
+=
1
self
.
count
+=
1
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
...
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
...
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return
self
.
expanded_search_space
[
self
.
count
]
return
self
.
expanded_search_space
[
self
.
count
]
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
pass
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
d6febf29
...
@@ -144,7 +144,7 @@ class Bracket():
...
@@ -144,7 +144,7 @@ class Bracket():
self
.
configs_perf
=
[]
# [ {id: [seq, acc]}, {}, ... ]
self
.
configs_perf
=
[]
# [ {id: [seq, acc]}, {}, ... ]
self
.
num_configs_to_run
=
[]
# [ n, n, n, ... ]
self
.
num_configs_to_run
=
[]
# [ n, n, n, ... ]
self
.
num_finished_configs
=
[]
# [ n, n, n, ... ]
self
.
num_finished_configs
=
[]
# [ n, n, n, ... ]
self
.
optimize_mode
=
optimize_mode
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
no_more_trial
=
False
self
.
no_more_trial
=
False
def
is_completed
(
self
):
def
is_completed
(
self
):
...
...
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
d6febf29
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose
=
0
)
verbose
=
0
)
self
.
rval
.
catch_eval_exceptions
=
False
self
.
rval
.
catch_eval_exceptions
=
False
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
Returns a set of trial (hyper-)parameters, as a serializable object.
...
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
...
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params
=
split_index
(
total_params
)
params
=
split_index
(
total_params
)
return
params
return
params
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
"""
Record an observation of the objective function
Record an observation of the objective function
...
...
src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py
View file @
d6febf29
...
@@ -49,15 +49,16 @@ def selection_r(x_bounds,
...
@@ -49,15 +49,16 @@ def selection_r(x_bounds,
num_starting_points
=
100
,
num_starting_points
=
100
,
minimize_constraints_fun
=
None
):
minimize_constraints_fun
=
None
):
'''
'''
Call selection
Select using different types.
'''
'''
minimize_starting_points
=
[
lib_data
.
rand
(
x_bounds
,
x_type
s
)
\
minimize_starting_points
=
clusteringmodel_gmm_good
.
sample
(
n_samples
=
num_starting_point
s
)
for
i
in
range
(
0
,
num_starting_points
)]
outputs
=
selection
(
x_bounds
,
x_types
,
outputs
=
selection
(
x_bounds
,
x_types
,
clusteringmodel_gmm_good
,
clusteringmodel_gmm_good
,
clusteringmodel_gmm_bad
,
clusteringmodel_gmm_bad
,
minimize_starting_points
,
minimize_starting_points
[
0
]
,
minimize_constraints_fun
)
minimize_constraints_fun
)
return
outputs
return
outputs
def
selection
(
x_bounds
,
def
selection
(
x_bounds
,
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment