Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f9ee589c
Unverified
Commit
f9ee589c
authored
Dec 24, 2019
by
SparkSnail
Committed by
GitHub
Dec 24, 2019
Browse files
Merge pull request #222 from microsoft/master
merge master
parents
36e6e350
4f3ee9cb
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
108 additions
and
27 deletions
+108
-27
src/webui/src/static/model/experiment.ts
src/webui/src/static/model/experiment.ts
+3
-0
src/webui/src/static/model/trial.ts
src/webui/src/static/model/trial.ts
+7
-0
src/webui/src/static/model/trialmanager.ts
src/webui/src/static/model/trialmanager.ts
+8
-0
test/generate_ts_config.py
test/generate_ts_config.py
+3
-3
test/training_service.yml
test/training_service.yml
+2
-2
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+43
-4
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+26
-0
tools/nni_cmd/launcher_utils.py
tools/nni_cmd/launcher_utils.py
+7
-12
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+8
-5
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+1
-1
No files found.
src/webui/src/static/model/experiment.ts
View file @
f9ee589c
...
@@ -41,6 +41,7 @@ class Experiment {
...
@@ -41,6 +41,7 @@ class Experiment {
if
(
!
this
.
profileField
)
{
if
(
!
this
.
profileField
)
{
throw
Error
(
'
Experiment profile not initialized
'
);
throw
Error
(
'
Experiment profile not initialized
'
);
}
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
profileField
!
;
return
this
.
profileField
!
;
}
}
...
@@ -73,6 +74,7 @@ class Experiment {
...
@@ -73,6 +74,7 @@ class Experiment {
if
(
!
this
.
statusField
)
{
if
(
!
this
.
statusField
)
{
throw
Error
(
'
Experiment status not initialized
'
);
throw
Error
(
'
Experiment status not initialized
'
);
}
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
statusField
!
.
status
;
return
this
.
statusField
!
.
status
;
}
}
...
@@ -80,6 +82,7 @@ class Experiment {
...
@@ -80,6 +82,7 @@ class Experiment {
if
(
!
this
.
statusField
)
{
if
(
!
this
.
statusField
)
{
throw
Error
(
'
Experiment status not initialized
'
);
throw
Error
(
'
Experiment status not initialized
'
);
}
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
statusField
!
.
errors
[
0
]
||
''
;
return
this
.
statusField
!
.
errors
[
0
]
||
''
;
}
}
}
}
...
...
src/webui/src/static/model/trial.ts
View file @
f9ee589c
...
@@ -19,10 +19,12 @@ class Trial implements TableObj {
...
@@ -19,10 +19,12 @@ class Trial implements TableObj {
if
(
!
this
.
sortable
||
!
otherTrial
.
sortable
)
{
if
(
!
this
.
sortable
||
!
otherTrial
.
sortable
)
{
return
undefined
;
return
undefined
;
}
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
finalAcc
!
-
otherTrial
.
finalAcc
!
;
return
this
.
finalAcc
!
-
otherTrial
.
finalAcc
!
;
}
}
get
info
():
TrialJobInfo
{
get
info
():
TrialJobInfo
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
infoField
!
;
return
this
.
infoField
!
;
}
}
...
@@ -30,6 +32,7 @@ class Trial implements TableObj {
...
@@ -30,6 +32,7 @@ class Trial implements TableObj {
const
ret
:
MetricDataRecord
[]
=
[
];
const
ret
:
MetricDataRecord
[]
=
[
];
for
(
let
i
=
0
;
i
<
this
.
intermediates
.
length
;
i
++
)
{
for
(
let
i
=
0
;
i
<
this
.
intermediates
.
length
;
i
++
)
{
if
(
this
.
intermediates
[
i
])
{
if
(
this
.
intermediates
[
i
])
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret
.
push
(
this
.
intermediates
[
i
]
!
);
ret
.
push
(
this
.
intermediates
[
i
]
!
);
}
else
{
}
else
{
break
;
break
;
...
@@ -66,12 +69,14 @@ class Trial implements TableObj {
...
@@ -66,12 +69,14 @@ class Trial implements TableObj {
get
tableRecord
():
TableRecord
{
get
tableRecord
():
TableRecord
{
const
endTime
=
this
.
info
.
endTime
||
new
Date
().
getTime
();
const
endTime
=
this
.
info
.
endTime
||
new
Date
().
getTime
();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const
duration
=
(
endTime
-
this
.
info
.
startTime
!
)
/
1000
;
const
duration
=
(
endTime
-
this
.
info
.
startTime
!
)
/
1000
;
return
{
return
{
key
:
this
.
info
.
id
,
key
:
this
.
info
.
id
,
sequenceId
:
this
.
info
.
sequenceId
,
sequenceId
:
this
.
info
.
sequenceId
,
id
:
this
.
info
.
id
,
id
:
this
.
info
.
id
,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
startTime
:
this
.
info
.
startTime
!
,
startTime
:
this
.
info
.
startTime
!
,
endTime
:
this
.
info
.
endTime
,
endTime
:
this
.
info
.
endTime
,
duration
,
duration
,
...
@@ -97,6 +102,7 @@ class Trial implements TableObj {
...
@@ -97,6 +102,7 @@ class Trial implements TableObj {
get
duration
():
number
{
get
duration
():
number
{
const
endTime
=
this
.
info
.
endTime
||
new
Date
().
getTime
();
const
endTime
=
this
.
info
.
endTime
||
new
Date
().
getTime
();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
(
endTime
-
this
.
info
.
startTime
!
)
/
1000
;
return
(
endTime
-
this
.
info
.
startTime
!
)
/
1000
;
}
}
...
@@ -203,6 +209,7 @@ class Trial implements TableObj {
...
@@ -203,6 +209,7 @@ class Trial implements TableObj {
}
else
if
(
this
.
intermediates
.
length
===
0
)
{
}
else
if
(
this
.
intermediates
.
length
===
0
)
{
return
'
--
'
;
return
'
--
'
;
}
else
{
}
else
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const
latest
=
this
.
intermediates
[
this
.
intermediates
.
length
-
1
]
!
;
const
latest
=
this
.
intermediates
[
this
.
intermediates
.
length
-
1
]
!
;
return
`
${
formatAccuracy
(
metricAccuracy
(
latest
))}
(LATEST)`
;
return
`
${
formatAccuracy
(
metricAccuracy
(
latest
))}
(LATEST)`
;
}
}
...
...
src/webui/src/static/model/trialmanager.ts
View file @
f9ee589c
...
@@ -7,6 +7,7 @@ function groupMetricsByTrial(metrics: MetricDataRecord[]): Map<string, MetricDat
...
@@ -7,6 +7,7 @@ function groupMetricsByTrial(metrics: MetricDataRecord[]): Map<string, MetricDat
const
ret
=
new
Map
<
string
,
MetricDataRecord
[]
>
();
const
ret
=
new
Map
<
string
,
MetricDataRecord
[]
>
();
for
(
const
metric
of
metrics
)
{
for
(
const
metric
of
metrics
)
{
if
(
ret
.
has
(
metric
.
trialJobId
))
{
if
(
ret
.
has
(
metric
.
trialJobId
))
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret
.
get
(
metric
.
trialJobId
)
!
.
push
(
metric
);
ret
.
get
(
metric
.
trialJobId
)
!
.
push
(
metric
);
}
else
{
}
else
{
ret
.
set
(
metric
.
trialJobId
,
[
metric
]);
ret
.
set
(
metric
.
trialJobId
,
[
metric
]);
...
@@ -35,14 +36,17 @@ class TrialManager {
...
@@ -35,14 +36,17 @@ class TrialManager {
}
}
public
getTrial
(
trialId
:
string
):
Trial
{
public
getTrial
(
trialId
:
string
):
Trial
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
trials
.
get
(
trialId
)
!
;
return
this
.
trials
.
get
(
trialId
)
!
;
}
}
public
getTrials
(
trialIds
:
string
[]):
Trial
[]
{
public
getTrials
(
trialIds
:
string
[]):
Trial
[]
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
trialIds
.
map
(
trialId
=>
this
.
trials
.
get
(
trialId
)
!
);
return
trialIds
.
map
(
trialId
=>
this
.
trials
.
get
(
trialId
)
!
);
}
}
public
table
(
trialIds
:
string
[]):
TableRecord
[]
{
public
table
(
trialIds
:
string
[]):
TableRecord
[]
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
trialIds
.
map
(
trialId
=>
this
.
trials
.
get
(
trialId
)
!
.
tableRecord
);
return
trialIds
.
map
(
trialId
=>
this
.
trials
.
get
(
trialId
)
!
.
tableRecord
);
}
}
...
@@ -61,6 +65,7 @@ class TrialManager {
...
@@ -61,6 +65,7 @@ class TrialManager {
}
}
public
sort
():
Trial
[]
{
public
sort
():
Trial
[]
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return
this
.
filter
(
trial
=>
trial
.
sortable
).
sort
((
trial1
,
trial2
)
=>
trial1
.
compareAccuracy
(
trial2
)
!
);
return
this
.
filter
(
trial
=>
trial
.
sortable
).
sort
((
trial1
,
trial2
)
=>
trial1
.
compareAccuracy
(
trial2
)
!
);
}
}
...
@@ -77,6 +82,7 @@ class TrialManager {
...
@@ -77,6 +82,7 @@ class TrialManager {
]);
]);
for
(
const
trial
of
this
.
trials
.
values
())
{
for
(
const
trial
of
this
.
trials
.
values
())
{
if
(
trial
.
initialized
())
{
if
(
trial
.
initialized
())
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
cnt
.
set
(
trial
.
info
.
status
,
cnt
.
get
(
trial
.
info
.
status
)
!
+
1
);
cnt
.
set
(
trial
.
info
.
status
,
cnt
.
get
(
trial
.
info
.
status
)
!
+
1
);
}
}
}
}
...
@@ -89,6 +95,7 @@ class TrialManager {
...
@@ -89,6 +95,7 @@ class TrialManager {
if
(
response
.
status
===
200
)
{
if
(
response
.
status
===
200
)
{
for
(
const
info
of
response
.
data
as
TrialJobInfo
[])
{
for
(
const
info
of
response
.
data
as
TrialJobInfo
[])
{
if
(
this
.
trials
.
has
(
info
.
id
))
{
if
(
this
.
trials
.
has
(
info
.
id
))
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
updated
=
this
.
trials
.
get
(
info
.
id
)
!
.
updateTrialJobInfo
(
info
)
||
updated
;
updated
=
this
.
trials
.
get
(
info
.
id
)
!
.
updateTrialJobInfo
(
info
)
||
updated
;
}
else
{
}
else
{
this
.
trials
.
set
(
info
.
id
,
new
Trial
(
info
,
undefined
));
this
.
trials
.
set
(
info
.
id
,
new
Trial
(
info
,
undefined
));
...
@@ -141,6 +148,7 @@ class TrialManager {
...
@@ -141,6 +148,7 @@ class TrialManager {
let
updated
=
false
;
let
updated
=
false
;
for
(
const
[
trialId
,
metrics
]
of
groupMetricsByTrial
(
allMetrics
).
entries
())
{
for
(
const
[
trialId
,
metrics
]
of
groupMetricsByTrial
(
allMetrics
).
entries
())
{
if
(
this
.
trials
.
has
(
trialId
))
{
if
(
this
.
trials
.
has
(
trialId
))
{
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const
trial
=
this
.
trials
.
get
(
trialId
)
!
;
const
trial
=
this
.
trials
.
get
(
trialId
)
!
;
updated
=
(
latestOnly
?
trial
.
updateLatestMetrics
(
metrics
)
:
trial
.
updateMetrics
(
metrics
))
||
updated
;
updated
=
(
latestOnly
?
trial
.
updateLatestMetrics
(
metrics
)
:
trial
.
updateMetrics
(
metrics
))
||
updated
;
}
else
{
}
else
{
...
...
test/generate_ts_config.py
View file @
f9ee589c
...
@@ -14,11 +14,11 @@ def update_training_service_config(args):
...
@@ -14,11 +14,11 @@ def update_training_service_config(args):
config
[
args
.
ts
][
'nniManagerIp'
]
=
args
.
nni_manager_ip
config
[
args
.
ts
][
'nniManagerIp'
]
=
args
.
nni_manager_ip
if
args
.
ts
==
'pai'
:
if
args
.
ts
==
'pai'
:
if
args
.
pai_user
is
not
None
:
if
args
.
pai_user
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'userName'
]
=
args
.
pai_user
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'userName'
]
=
args
.
pai_user
if
args
.
pai_pwd
is
not
None
:
if
args
.
pai_pwd
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'passWord'
]
=
args
.
pai_pwd
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'passWord'
]
=
args
.
pai_pwd
if
args
.
pai_host
is
not
None
:
if
args
.
pai_host
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'host'
]
=
args
.
pai_host
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'host'
]
=
args
.
pai_host
if
args
.
nni_docker_image
is
not
None
:
if
args
.
nni_docker_image
is
not
None
:
config
[
args
.
ts
][
'trial'
][
'image'
]
=
args
.
nni_docker_image
config
[
args
.
ts
][
'trial'
][
'image'
]
=
args
.
nni_docker_image
if
args
.
data_dir
is
not
None
:
if
args
.
data_dir
is
not
None
:
...
...
test/training_service.yml
View file @
f9ee589c
...
@@ -29,11 +29,11 @@ local:
...
@@ -29,11 +29,11 @@ local:
pai
:
pai
:
nniManagerIp
:
nniManagerIp
:
maxExecDuration
:
15m
maxExecDuration
:
15m
paiConfig
:
pai
Yarn
Config
:
host
:
host
:
passWord
:
passWord
:
userName
:
userName
:
trainingServicePlatform
:
pai
trainingServicePlatform
:
pai
Yarn
trial
:
trial
:
gpuNum
:
1
gpuNum
:
1
cpuNum
:
1
cpuNum
:
1
...
...
tools/nni_cmd/config_schema.py
View file @
f9ee589c
...
@@ -32,7 +32,7 @@ common_schema = {
...
@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency'
:
setNumberRange
(
'trialConcurrency'
,
int
,
1
,
99999
),
'trialConcurrency'
:
setNumberRange
(
'trialConcurrency'
,
int
,
1
,
99999
),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
@@ -53,14 +53,23 @@ common_schema = {
...
@@ -53,14 +53,23 @@ common_schema = {
}
}
}
}
tuner_schema_dict
=
{
tuner_schema_dict
=
{
(
'Anneal'
,
'SMAC'
)
:
{
'Anneal'
:
{
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'Anneal'
,
'SMAC'
)
,
'builtinTunerName'
:
'Anneal'
,
Optional
(
'classArgs'
):
{
Optional
(
'classArgs'
):
{
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
},
},
'SMAC'
:
{
'builtinTunerName'
:
'SMAC'
,
Optional
(
'classArgs'
):
{
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
'config_dedup'
:
setType
(
'config_dedup'
,
bool
)
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
},
(
'Evolution'
):
{
(
'Evolution'
):
{
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'Evolution'
),
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'Evolution'
),
Optional
(
'classArgs'
):
{
Optional
(
'classArgs'
):
{
...
@@ -223,7 +232,7 @@ common_trial_schema = {
...
@@ -223,7 +232,7 @@ common_trial_schema = {
}
}
}
}
pai_trial_schema
=
{
pai_
yarn_
trial_schema
=
{
'trial'
:{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
@@ -247,6 +256,34 @@ pai_trial_schema = {
...
@@ -247,6 +256,34 @@ pai_trial_schema = {
}
}
}
}
pai_yarn_config_schema
=
{
'paiYarnConfig'
:
Or
({
'userName'
:
setType
(
'userName'
,
str
),
'passWord'
:
setType
(
'passWord'
,
str
),
'host'
:
setType
(
'host'
,
str
)
},
{
'userName'
:
setType
(
'userName'
,
str
),
'token'
:
setType
(
'token'
,
str
),
'host'
:
setType
(
'host'
,
str
)
})
}
pai_trial_schema
=
{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
'cpuNum'
:
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
'memoryMB'
:
setType
(
'memoryMB'
,
int
),
'image'
:
setType
(
'image'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
'nniManagerNFSMountPath'
:
setPathCheck
(
'nniManagerNFSMountPath'
),
'containerNFSMountPath'
:
setType
(
'containerNFSMountPath'
,
str
),
'paiStoragePlugin'
:
setType
(
'paiStoragePlugin'
,
str
)
}
}
pai_config_schema
=
{
pai_config_schema
=
{
'paiConfig'
:
Or
({
'paiConfig'
:
Or
({
'userName'
:
setType
(
'userName'
,
str
),
'userName'
:
setType
(
'userName'
,
str
),
...
@@ -396,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
...
@@ -396,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
PAI_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
pai_trial_schema
,
**
pai_config_schema
})
PAI_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
pai_trial_schema
,
**
pai_config_schema
})
PAI_YARN_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
pai_yarn_trial_schema
,
**
pai_yarn_config_schema
})
KUBEFLOW_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
})
KUBEFLOW_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
})
tools/nni_cmd/launcher.py
View file @
f9ee589c
...
@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
...
@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_pai_yarn_config
(
experiment_config
,
port
,
config_file_name
):
'''set paiYarn configuration'''
pai_yarn_config_data
=
dict
()
pai_yarn_config_data
[
'pai_yarn_config'
]
=
experiment_config
[
'paiYarnConfig'
]
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
pai_yarn_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
return
result
,
message
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
):
def
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
):
'''set kubeflow configuration'''
'''set kubeflow configuration'''
kubeflow_config_data
=
dict
()
kubeflow_config_data
=
dict
()
...
@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{
'key'
:
'pai_config'
,
'value'
:
experiment_config
[
'paiConfig'
]})
{
'key'
:
'pai_config'
,
'value'
:
experiment_config
[
'paiConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'paiYarn'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'pai_yarn_config'
,
'value'
:
experiment_config
[
'paiYarnConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'kubeflow'
:
elif
experiment_config
[
'trainingServicePlatform'
]
==
'kubeflow'
:
request_data
[
'clusterMetaData'
].
append
(
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'kubeflow_config'
,
'value'
:
experiment_config
[
'kubeflowConfig'
]})
{
'key'
:
'kubeflow_config'
,
'value'
:
experiment_config
[
'kubeflowConfig'
]})
...
@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
...
@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'pai'
:
elif
platform
==
'pai'
:
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'paiYarn'
:
config_result
,
err_msg
=
set_pai_yarn_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'kubeflow'
:
elif
platform
==
'kubeflow'
:
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'frameworkcontroller'
:
elif
platform
==
'frameworkcontroller'
:
...
...
tools/nni_cmd/launcher_utils.py
View file @
f9ee589c
...
@@ -5,7 +5,7 @@ import os
...
@@ -5,7 +5,7 @@ import os
import
json
import
json
from
schema
import
SchemaError
from
schema
import
SchemaError
from
schema
import
Schema
from
schema
import
Schema
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
PAI_YARN_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
from
.common_utils
import
print_error
,
print_warning
,
print_normal
from
.common_utils
import
print_error
,
print_warning
,
print_normal
...
@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
...
@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def
validate_common_content
(
experiment_config
):
def
validate_common_content
(
experiment_config
):
'''Validate whether the common values in experiment_config is valid'''
'''Validate whether the common values in experiment_config is valid'''
if
not
experiment_config
.
get
(
'trainingServicePlatform'
)
or
\
if
not
experiment_config
.
get
(
'trainingServicePlatform'
)
or
\
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
]:
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
]:
print_error
(
'Please set correct trainingServicePlatform!'
)
print_error
(
'Please set correct trainingServicePlatform!'
)
exit
(
1
)
exit
(
1
)
schema_dict
=
{
schema_dict
=
{
'local'
:
LOCAL_CONFIG_SCHEMA
,
'local'
:
LOCAL_CONFIG_SCHEMA
,
'remote'
:
REMOTE_CONFIG_SCHEMA
,
'remote'
:
REMOTE_CONFIG_SCHEMA
,
'pai'
:
PAI_CONFIG_SCHEMA
,
'pai'
:
PAI_CONFIG_SCHEMA
,
'paiYarn'
:
PAI_YARN_CONFIG_SCHEMA
,
'kubeflow'
:
KUBEFLOW_CONFIG_SCHEMA
,
'kubeflow'
:
KUBEFLOW_CONFIG_SCHEMA
,
'frameworkcontroller'
:
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
'frameworkcontroller'
:
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
}
...
@@ -213,24 +214,18 @@ def validate_customized_file(experiment_config, spec_key):
...
@@ -213,24 +214,18 @@ def validate_customized_file(experiment_config, spec_key):
def
parse_tuner_content
(
experiment_config
):
def
parse_tuner_content
(
experiment_config
):
'''Validate whether tuner in experiment_config is valid'''
'''Validate whether tuner in experiment_config is valid'''
if
experiment_config
[
'tuner'
].
get
(
'builtinTunerName'
):
if
not
experiment_config
[
'tuner'
].
get
(
'builtinTunerName'
):
experiment_config
[
'tuner'
][
'className'
]
=
experiment_config
[
'tuner'
][
'builtinTunerName'
]
else
:
validate_customized_file
(
experiment_config
,
'tuner'
)
validate_customized_file
(
experiment_config
,
'tuner'
)
def
parse_assessor_content
(
experiment_config
):
def
parse_assessor_content
(
experiment_config
):
'''Validate whether assessor in experiment_config is valid'''
'''Validate whether assessor in experiment_config is valid'''
if
experiment_config
.
get
(
'assessor'
):
if
experiment_config
.
get
(
'assessor'
):
if
experiment_config
[
'assessor'
].
get
(
'builtinAssessorName'
):
if
not
experiment_config
[
'assessor'
].
get
(
'builtinAssessorName'
):
experiment_config
[
'assessor'
][
'className'
]
=
experiment_config
[
'assessor'
][
'builtinAssessorName'
]
else
:
validate_customized_file
(
experiment_config
,
'assessor'
)
validate_customized_file
(
experiment_config
,
'assessor'
)
def
parse_advisor_content
(
experiment_config
):
def
parse_advisor_content
(
experiment_config
):
'''Validate whether advisor in experiment_config is valid'''
'''Validate whether advisor in experiment_config is valid'''
if
experiment_config
[
'advisor'
].
get
(
'builtinAdvisorName'
):
if
not
experiment_config
[
'advisor'
].
get
(
'builtinAdvisorName'
):
experiment_config
[
'advisor'
][
'className'
]
=
experiment_config
[
'advisor'
][
'builtinAdvisorName'
]
else
:
validate_customized_file
(
experiment_config
,
'advisor'
)
validate_customized_file
(
experiment_config
,
'advisor'
)
def
validate_annotation_content
(
experiment_config
,
spec_key
,
builtin_name
):
def
validate_annotation_content
(
experiment_config
,
spec_key
,
builtin_name
):
...
@@ -261,7 +256,7 @@ def validate_machine_list(experiment_config):
...
@@ -261,7 +256,7 @@ def validate_machine_list(experiment_config):
def
validate_pai_trial_conifg
(
experiment_config
):
def
validate_pai_trial_conifg
(
experiment_config
):
'''validate the trial config in pai platform'''
'''validate the trial config in pai platform'''
if
experiment_config
.
get
(
'trainingServicePlatform'
)
==
'pai'
:
if
experiment_config
.
get
(
'trainingServicePlatform'
)
in
[
'pai'
,
'paiYarn'
]
:
if
experiment_config
.
get
(
'trial'
).
get
(
'shmMB'
)
and
\
if
experiment_config
.
get
(
'trial'
).
get
(
'shmMB'
)
and
\
experiment_config
[
'trial'
][
'shmMB'
]
>
experiment_config
[
'trial'
][
'memoryMB'
]:
experiment_config
[
'trial'
][
'shmMB'
]
>
experiment_config
[
'trial'
][
'memoryMB'
]:
print_error
(
'shmMB should be no more than memoryMB!'
)
print_error
(
'shmMB should be no more than memoryMB!'
)
...
...
tools/nni_cmd/nnictl_utils.py
View file @
f9ee589c
...
@@ -682,10 +682,13 @@ def search_space_auto_gen(args):
...
@@ -682,10 +682,13 @@ def search_space_auto_gen(args):
trial_dir
=
os
.
path
.
expanduser
(
args
.
trial_dir
)
trial_dir
=
os
.
path
.
expanduser
(
args
.
trial_dir
)
file_path
=
os
.
path
.
expanduser
(
args
.
file
)
file_path
=
os
.
path
.
expanduser
(
args
.
file
)
if
not
os
.
path
.
isabs
(
file_path
):
if
not
os
.
path
.
isabs
(
file_path
):
abs_
file_path
=
os
.
path
.
join
(
os
.
getcwd
(),
file_path
)
file_path
=
os
.
path
.
join
(
os
.
getcwd
(),
file_path
)
assert
os
.
path
.
exists
(
trial_dir
)
assert
os
.
path
.
exists
(
trial_dir
)
if
os
.
path
.
exists
(
abs_
file_path
):
if
os
.
path
.
exists
(
file_path
):
print_warning
(
'%s already exits, will be over
written'
%
abs_
file_path
)
print_warning
(
'%s already exi
s
ts, will be overwritten
.
'
%
file_path
)
print_normal
(
'Dry run to generate search space...'
)
print_normal
(
'Dry run to generate search space...'
)
Popen
(
args
.
trial_command
,
cwd
=
trial_dir
,
env
=
dict
(
os
.
environ
,
NNI_GEN_SEARCH_SPACE
=
abs_file_path
),
shell
=
True
).
wait
()
Popen
(
args
.
trial_command
,
cwd
=
trial_dir
,
env
=
dict
(
os
.
environ
,
NNI_GEN_SEARCH_SPACE
=
file_path
),
shell
=
True
).
wait
()
print_normal
(
'Dry run to generate search space, Done'
)
if
not
os
.
path
.
exists
(
file_path
):
\ No newline at end of file
print_warning
(
'Expected search space file
\'
{}
\'
generated, but not found.'
.
format
(
file_path
))
else
:
print_normal
(
'Generate search space done:
\'
{}
\'
.'
.
format
(
file_path
))
tools/nni_trial_tool/trial_keeper.py
View file @
f9ee589c
...
@@ -223,7 +223,7 @@ if __name__ == '__main__':
...
@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit
(
1
)
exit
(
1
)
check_version
(
args
)
check_version
(
args
)
try
:
try
:
if
NNI_PLATFORM
==
'pai'
and
is_multi_phase
():
if
NNI_PLATFORM
==
'pai
Yarn
'
and
is_multi_phase
():
fetch_parameter_file
(
args
)
fetch_parameter_file
(
args
)
main_loop
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
except
SystemExit
as
se
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment