Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
24fa4619
Unverified
Commit
24fa4619
authored
Feb 19, 2020
by
QuanluZhang
Committed by
GitHub
Feb 19, 2020
Browse files
Merge pull request #2081 from microsoft/v1.4
merge V1.4 back to master
parents
aaaa2756
8ff039c2
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
121 additions
and
82 deletions
+121
-82
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+2
-0
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+5
-0
src/sdk/pynni/nni/nas/pytorch/utils.py
src/sdk/pynni/nni/nas/pytorch/utils.py
+0
-2
src/sdk/pynni/nni/platform/test.py
src/sdk/pynni/nni/platform/test.py
+4
-1
src/sdk/pynni/nni/trial.py
src/sdk/pynni/nni/trial.py
+2
-2
src/sdk/pynni/tests/test_assessor.py
src/sdk/pynni/tests/test_assessor.py
+3
-3
src/sdk/pynni/tests/test_msg_dispatcher.py
src/sdk/pynni/tests/test_msg_dispatcher.py
+2
-2
src/webui/src/components/Modals/ChangeColumnComponent.tsx
src/webui/src/components/Modals/ChangeColumnComponent.tsx
+1
-2
src/webui/src/components/NavCon.tsx
src/webui/src/components/NavCon.tsx
+1
-1
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+2
-2
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+44
-30
src/webui/src/static/function.ts
src/webui/src/static/function.ts
+16
-11
src/webui/src/static/interface.ts
src/webui/src/static/interface.ts
+2
-1
src/webui/src/static/model/trial.ts
src/webui/src/static/model/trial.ts
+8
-3
src/webui/src/static/style/search.scss
src/webui/src/static/style/search.scss
+3
-2
src/webui/src/static/style/table.scss
src/webui/src/static/style/table.scss
+5
-1
test/metrics_test.py
test/metrics_test.py
+2
-2
test/pipelines-it-remote.yml
test/pipelines-it-remote.yml
+1
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+14
-12
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+4
-4
No files found.
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
24fa4619
...
...
@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase):
ValueError
Data type not supported
"""
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
loads
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
24fa4619
...
...
@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for
entry
in
data
:
entry
[
'value'
]
=
json_tricks
.
loads
(
entry
[
'value'
])
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
...
...
@@ -127,6 +129,9 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
# metrics value is dumped as json string in trial, so we need to decode it here
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
loads
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
FINAL
:
self
.
_handle_final_metric_data
(
data
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
src/sdk/pynni/nni/nas/pytorch/utils.py
View file @
24fa4619
...
...
@@ -116,8 +116,6 @@ class AverageMeter:
n : int
The weight of the new value.
"""
if
not
isinstance
(
val
,
float
)
and
not
isinstance
(
val
,
int
):
_logger
.
warning
(
"Values passed to AverageMeter must be number, not %s."
,
type
(
val
))
self
.
val
=
val
self
.
sum
+=
val
*
n
self
.
count
+=
n
...
...
src/sdk/pynni/nni/platform/test.py
View file @
24fa4619
...
...
@@ -33,4 +33,7 @@ def init_params(params):
_params
=
copy
.
deepcopy
(
params
)
def
get_last_metric
():
return
json_tricks
.
loads
(
_last_metric
)
metrics
=
json_tricks
.
loads
(
_last_metric
)
metrics
[
'value'
]
=
json_tricks
.
loads
(
metrics
[
'value'
])
return
metrics
src/sdk/pynni/nni/trial.py
View file @
24fa4619
...
...
@@ -114,7 +114,7 @@ def report_intermediate_result(metric):
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'PERIODICAL'
,
'sequence'
:
_intermediate_seq
,
'value'
:
metric
'value'
:
to_json
(
metric
)
})
_intermediate_seq
+=
1
platform
.
send_metric
(
metric
)
...
...
@@ -135,6 +135,6 @@ def report_final_result(metric):
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'FINAL'
,
'sequence'
:
0
,
'value'
:
metric
'value'
:
to_json
(
metric
)
})
platform
.
send_metric
(
metric
)
src/sdk/pynni/tests/test_assessor.py
View file @
24fa4619
...
...
@@ -47,9 +47,9 @@ def _restore_io():
class
AssessorTestCase
(
TestCase
):
def
test_assessor
(
self
):
_reverse_io
()
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":
2
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":
2
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":
3
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":
"2"
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":
"2"
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":
"3"
}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"A","event":"SYS_CANCELED"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"B","event":"SUCCEEDED"}'
)
send
(
CommandType
.
NewTrialJob
,
'null'
)
...
...
src/sdk/pynni/tests/test_msg_dispatcher.py
View file @
24fa4619
...
...
@@ -59,8 +59,8 @@ class MsgDispatcherTestCase(TestCase):
def
test_msg_dispatcher
(
self
):
_reverse_io
()
# now we are sending to Tuner's incoming stream
send
(
CommandType
.
RequestTrialJobs
,
'2'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":0,"type":"PERIODICAL","value":10}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":1,"type":"FINAL","value":11}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":0,"type":"PERIODICAL","value":
"
10
"
}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":1,"type":"FINAL","value":
"
11
"
}'
)
send
(
CommandType
.
UpdateSearchSpace
,
'{"name":"SS0"}'
)
send
(
CommandType
.
RequestTrialJobs
,
'1'
)
send
(
CommandType
.
KillTrialJob
,
'null'
)
...
...
src/webui/src/components/Modals/ChangeColumnComponent.tsx
View file @
24fa4619
...
...
@@ -117,7 +117,6 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol
});
return
(
<
div
>
<
div
>
Hello
</
div
>
<
Dialog
hidden
=
{
isHideDialog
}
// required field!
dialogContentProps
=
{
{
...
...
@@ -130,7 +129,7 @@ class ChangeColumnComponent extends React.Component<ChangeColumnProps, ChangeCol
styles
:
{
main
:
{
maxWidth
:
450
}
}
}
}
>
<
div
>
<
div
className
=
"columns-height"
>
{
renderOptions
.
map
(
item
=>
{
return
<
Checkbox
key
=
{
item
.
label
}
{
...
item
}
styles
=
{
{
root
:
{
marginBottom
:
8
}
}
}
/>
})
}
...
...
src/webui/src/components/NavCon.tsx
View file @
24fa4619
...
...
@@ -172,7 +172,7 @@ class NavCon extends React.Component<NavProps, NavState> {
/>
<
CommandBarButton
iconProps
=
{
infoIconAbout
}
text
=
"
a
bout"
text
=
"
A
bout"
menuProps
=
{
aboutProps
}
/>
</
Stack
>
...
...
src/webui/src/components/TrialsDetail.tsx
View file @
24fa4619
...
...
@@ -56,7 +56,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
return
;
}
switch
(
this
.
state
.
searchType
)
{
case
'
i
d
'
:
case
'
I
d
'
:
filter
=
(
trial
):
boolean
=>
trial
.
info
.
id
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
());
break
;
case
'
Trial No.
'
:
...
...
@@ -65,7 +65,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
case
'
Status
'
:
filter
=
(
trial
):
boolean
=>
trial
.
info
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
());
break
;
case
'
p
arameters
'
:
case
'
P
arameters
'
:
// TODO: support filters like `x: 2` (instead of `"x": 2`)
filter
=
(
trial
):
boolean
=>
JSON
.
stringify
(
trial
.
info
.
hyperParameters
,
null
,
4
).
includes
(
targetValue
);
break
;
...
...
src/webui/src/components/trial-detail/TableList.tsx
View file @
24fa4619
...
...
@@ -54,7 +54,7 @@ interface TableListState {
isShowCustomizedModal
:
boolean
;
copyTrialId
:
string
;
// user copy trial to submit a new customized trial
isCalloutVisible
:
boolean
;
// kill job button callout [kill or not kill job window]
intermediateKey
s
:
string
[]
;
// intermeidate modal: which key is choosed.
intermediateKey
:
string
;
// intermeidate modal: which key is choosed.
isExpand
:
boolean
;
modalIntermediateWidth
:
number
;
modalIntermediateHeight
:
number
;
...
...
@@ -86,7 +86,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
isShowCustomizedModal
:
false
,
isCalloutVisible
:
false
,
copyTrialId
:
''
,
intermediateKey
s
:
[
'
default
'
]
,
intermediateKey
:
'
default
'
,
isExpand
:
false
,
modalIntermediateWidth
:
window
.
innerWidth
,
modalIntermediateHeight
:
window
.
innerHeight
,
...
...
@@ -128,7 +128,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
name
:
'
Default metric
'
,
className
:
'
leftTitle
'
,
key
:
'
accuracy
'
,
fieldName
:
'
a
ccuracy
'
,
fieldName
:
'
latestA
ccuracy
'
,
minWidth
:
200
,
maxWidth
:
300
,
isResizable
:
true
,
...
...
@@ -294,7 +294,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
intermediateId
);
// re-render
this
.
setState
({
intermediateKey
s
:
[
value
]
,
intermediateKey
:
value
,
intermediateOption
:
intermediate
});
}
...
...
@@ -388,29 +388,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
parameterStr
.
push
(
`
${
value
}
(search space)`
);
});
}
let
allColumnList
=
COLUMNPro
;
// eslint-disable-line @typescript-eslint/no-unused-vars
allColumnList
=
COLUMNPro
.
concat
(
parameterStr
);
let
allColumnList
=
COLUMNPro
.
concat
(
parameterStr
);
// only succeed trials have final keys
if
(
tableSource
.
filter
(
record
=>
record
.
status
===
'
SUCCEEDED
'
).
length
>=
1
)
{
const
temp
=
tableSource
.
filter
(
record
=>
record
.
status
===
'
SUCCEEDED
'
)[
0
].
acc
urac
y
;
const
temp
=
tableSource
.
filter
(
record
=>
record
.
status
===
'
SUCCEEDED
'
)[
0
].
acc
Dictionar
y
;
if
(
temp
!==
undefined
&&
typeof
temp
===
'
object
'
)
{
if
(
!
isNaN
(
temp
))
{
// concat default column and finalkeys
const
item
=
Object
.
keys
(
temp
);
// item: ['default', 'other-keys', 'maybe loss']
if
(
item
.
length
>
1
)
{
const
want
:
string
[]
=
[];
item
.
forEach
(
value
=>
{
if
(
value
!==
'
default
'
)
{
want
.
push
(
value
);
}
});
allColumnList
=
COLUMNPro
.
concat
(
want
);
}
// concat default column and finalkeys
const
item
=
Object
.
keys
(
temp
);
// item: ['default', 'other-keys', 'maybe loss']
if
(
item
.
length
>
1
)
{
const
want
:
string
[]
=
[];
item
.
forEach
(
value
=>
{
if
(
value
!==
'
default
'
)
{
want
.
push
(
value
);
}
});
allColumnList
=
allColumnList
.
concat
(
want
);
}
}
}
return
allColumnList
;
}
...
...
@@ -522,8 +520,22 @@ class TableList extends React.Component<TableListProps, TableListState> {
});
break
;
default
:
// FIXME
alert
(
'
Unexpected column type
'
);
showColumn
.
push
({
name
:
item
,
key
:
item
,
fieldName
:
item
,
minWidth
:
100
,
onRender
:
(
record
:
TableRecord
)
=>
{
const
accDictionary
=
record
.
accDictionary
;
let
other
=
''
;
if
(
accDictionary
!==
undefined
)
{
other
=
accDictionary
[
item
].
toString
();
}
return
(
<
div
>
{
other
}
</
div
>
);
}
});
}
}
return
showColumn
;
...
...
@@ -534,19 +546,22 @@ class TableList extends React.Component<TableListProps, TableListState> {
}
UNSAFE_componentWillReceiveProps
(
nextProps
:
TableListProps
):
void
{
const
{
columnList
}
=
nextProps
;
this
.
setState
({
tableColumns
:
this
.
initTableColumnList
(
columnList
)
});
const
{
columnList
,
tableSource
}
=
nextProps
;
this
.
setState
({
tableSourceForSort
:
tableSource
,
tableColumns
:
this
.
initTableColumnList
(
columnList
),
allColumnList
:
this
.
getAllColumnKeys
()
});
}
render
():
React
.
ReactNode
{
const
{
intermediateKey
s
,
modalIntermediateWidth
,
modalIntermediateHeight
,
const
{
intermediateKey
,
modalIntermediateWidth
,
modalIntermediateHeight
,
tableColumns
,
allColumnList
,
isShowColumn
,
modalVisible
,
selectRows
,
isShowCompareModal
,
intermediateOtherKeys
,
isShowCustomizedModal
,
copyTrialId
,
intermediateOption
}
=
this
.
state
;
const
{
columnList
}
=
this
.
props
;
const
tableSource
:
Array
<
TableRecord
>
=
JSON
.
parse
(
JSON
.
stringify
(
this
.
state
.
tableSourceForSort
));
return
(
<
Stack
>
<
div
id
=
"tableList"
>
...
...
@@ -580,11 +595,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
{
intermediateOtherKeys
.
length
>
1
?
<
Stack
className
=
"selectKeys"
styles
=
{
{
root
:
{
width
:
800
}
}
}
>
<
Stack
horizontalAlign
=
"end"
className
=
"selectKeys"
>
<
Dropdown
className
=
"select"
selectedKeys
=
{
intermediateKeys
}
onChange
=
{
this
.
selectOtherKeys
}
selectedKey
=
{
intermediateKey
}
options
=
{
intermediateOtherKeys
.
map
((
key
,
item
)
=>
{
return
{
...
...
@@ -592,7 +606,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
};
})
}
styles
=
{
{
dropdown
:
{
width
:
300
}
}
}
onChange
=
{
this
.
selectOtherKeys
}
/>
</
Stack
>
:
...
...
src/webui/src/static/function.ts
View file @
24fa4619
...
...
@@ -37,13 +37,21 @@ const convertDuration = (num: number): string => {
return
result
.
join
(
'
'
);
};
function
parseMetrics
(
metricData
:
string
):
any
{
if
(
metricData
.
includes
(
'
NaN
'
))
{
return
JSON5
.
parse
(
JSON5
.
parse
(
metricData
));
}
else
{
return
JSON
.
parse
(
JSON
.
parse
(
metricData
));
}
}
// get final result value
// draw Accuracy point graph
const
getFinalResult
=
(
final
?:
MetricDataRecord
[]):
number
=>
{
let
acc
;
let
showDefault
=
0
;
if
(
final
)
{
acc
=
JSON
.
parse
(
final
[
final
.
length
-
1
].
data
);
acc
=
parse
Metrics
(
final
[
final
.
length
-
1
].
data
);
if
(
typeof
(
acc
)
===
'
object
'
)
{
if
(
acc
.
default
)
{
showDefault
=
acc
.
default
;
...
...
@@ -61,7 +69,7 @@ const getFinalResult = (final?: MetricDataRecord[]): number => {
const
getFinal
=
(
final
?:
MetricDataRecord
[]):
FinalType
|
undefined
=>
{
let
showDefault
:
FinalType
;
if
(
final
)
{
showDefault
=
JSON
.
parse
(
final
[
final
.
length
-
1
].
data
);
showDefault
=
parse
Metrics
(
final
[
final
.
length
-
1
].
data
);
if
(
typeof
showDefault
===
'
number
'
)
{
showDefault
=
{
default
:
showDefault
};
}
...
...
@@ -179,17 +187,14 @@ function formatTimestamp(timestamp?: number, placeholder?: string): string {
return
timestamp
?
new
Date
(
timestamp
).
toLocaleString
(
'
en-US
'
)
:
placeholder
;
}
function
parseMetrics
(
metricData
:
string
):
any
{
if
(
metricData
.
includes
(
'
NaN
'
))
{
return
JSON5
.
parse
(
metricData
)
}
else
{
return
JSON
.
parse
(
metricData
)
}
}
function
metricAccuracy
(
metric
:
MetricDataRecord
):
number
{
const
data
=
parseMetrics
(
metric
.
data
);
return
typeof
data
===
'
number
'
?
data
:
NaN
;
// return typeof data === 'number' ? data : NaN;
if
(
typeof
data
===
'
number
'
)
{
return
data
;
}
else
{
return
data
.
default
;
}
}
function
formatAccuracy
(
accuracy
:
number
):
string
{
...
...
src/webui/src/static/interface.ts
View file @
24fa4619
...
...
@@ -23,7 +23,8 @@ interface TableRecord {
intermediateCount
:
number
;
accuracy
?:
number
;
latestAccuracy
:
number
|
undefined
;
formattedLatestAccuracy
:
string
;
// format (LATEST/FINAL)
formattedLatestAccuracy
:
string
;
// format (LATEST/FINAL),
accDictionary
:
FinalType
|
undefined
;
}
interface
SearchSpace
{
...
...
src/webui/src/static/model/trial.ts
View file @
24fa4619
...
...
@@ -53,10 +53,13 @@ class Trial implements TableObj {
if
(
this
.
accuracy
!==
undefined
)
{
return
this
.
accuracy
;
}
else
if
(
this
.
intermediates
.
length
>
0
)
{
// TODO: support intermeidate result is dict
const
temp
=
this
.
intermediates
[
this
.
intermediates
.
length
-
1
];
if
(
temp
!==
undefined
)
{
return
parseMetrics
(
temp
.
data
);
if
(
typeof
parseMetrics
(
temp
.
data
)
===
'
object
'
)
{
return
parseMetrics
(
temp
.
data
).
default
;
}
else
{
return
parseMetrics
(
temp
.
data
);
}
}
else
{
return
undefined
;
}
...
...
@@ -82,9 +85,11 @@ class Trial implements TableObj {
duration
,
status
:
this
.
info
.
status
,
intermediateCount
:
this
.
intermediates
.
length
,
accuracy
:
this
.
finalAcc
,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
accuracy
:
this
.
acc
!==
undefined
?
JSON
.
parse
(
this
.
acc
!
.
default
)
:
undefined
,
latestAccuracy
:
this
.
latestAccuracy
,
formattedLatestAccuracy
:
this
.
formatLatestAccuracy
(),
accDictionary
:
this
.
acc
};
}
...
...
src/webui/src/static/style/search.scss
View file @
24fa4619
...
...
@@ -27,9 +27,10 @@
.selectKeys
{
/* intermediate result is dict, select box for keys */
.select
{
margin-right
:
12%
;
}
.ms-Dropdown
{
width
:
120px
;
float
:
right
;
margin-right
:
10%
;
}
}
src/webui/src/static/style/table.scss
View file @
24fa4619
...
...
@@ -46,4 +46,8 @@
}
.detail-table
{
padding
:
5px
0
0
0
;
}
\ No newline at end of file
}
.columns-height
{
max-height
:
335px
;
overflow-y
:
scroll
;
}
test/metrics_test.py
View file @
24fa4619
...
...
@@ -56,9 +56,9 @@ def get_metric_results(metrics):
final_result
=
[]
for
metric
in
metrics
:
if
metric
[
'type'
]
==
'PERIODICAL'
:
intermediate_result
.
append
(
metric
[
'data'
])
intermediate_result
.
append
(
json
.
loads
(
metric
[
'data'
])
)
elif
metric
[
'type'
]
==
'FINAL'
:
final_result
.
append
(
metric
[
'data'
])
final_result
.
append
(
json
.
loads
(
metric
[
'data'
])
)
print
(
intermediate_result
,
final_result
)
return
[
round
(
float
(
x
),
6
)
for
x
in
intermediate_result
],
[
round
(
float
(
x
),
6
)
for
x
in
final_result
]
...
...
test/pipelines-it-remote.yml
View file @
24fa4619
...
...
@@ -59,7 +59,7 @@ jobs:
displayName
:
'
integration
test'
-
task
:
SSH@0
inputs
:
sshEndpoint
:
remote_nni-ci-gpu-01
sshEndpoint
:
$(end_point)
runOptions
:
commands
commands
:
python3 /tmp/nnitest/$(Build.BuildId)/test/remote_docker.py --mode stop --name $(Build.BuildId)
displayName
:
'
Stop
docker'
tools/nni_cmd/launcher.py
View file @
24fa4619
...
...
@@ -78,17 +78,17 @@ def get_nni_installation_path():
print_error
(
'Fail to find nni under python library'
)
exit
(
1
)
def
start_rest_server
(
args
,
platform
,
mode
,
config_file_name
,
experiment_id
=
None
,
log_dir
=
None
,
log_level
=
None
):
def
start_rest_server
(
port
,
platform
,
mode
,
config_file_name
,
foreground
=
False
,
experiment_id
=
None
,
log_dir
=
None
,
log_level
=
None
):
'''Run nni manager process'''
if
detect_port
(
args
.
port
):
if
detect_port
(
port
):
print_error
(
'Port %s is used by another process, please reset the port!
\n
'
\
'You could use
\'
nnictl create --help
\'
to get help information'
%
args
.
port
)
'You could use
\'
nnictl create --help
\'
to get help information'
%
port
)
exit
(
1
)
if
(
platform
!=
'local'
)
and
detect_port
(
int
(
args
.
port
)
+
1
):
if
(
platform
!=
'local'
)
and
detect_port
(
int
(
port
)
+
1
):
print_error
(
'PAI mode need an additional adjacent port %d, and the port %d is used by another process!
\n
'
\
'You could set another port to start experiment!
\n
'
\
'You could use
\'
nnictl create --help
\'
to get help information'
%
((
int
(
args
.
port
)
+
1
),
(
int
(
args
.
port
)
+
1
)))
'You could use
\'
nnictl create --help
\'
to get help information'
%
((
int
(
port
)
+
1
),
(
int
(
port
)
+
1
)))
exit
(
1
)
print_normal
(
'Starting restful server...'
)
...
...
@@ -99,7 +99,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
node_command
=
'node'
if
sys
.
platform
==
'win32'
:
node_command
=
os
.
path
.
join
(
entry_dir
[:
-
3
],
'Scripts'
,
'node.exe'
)
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
args
.
port
),
'--mode'
,
platform
]
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
]
if
mode
==
'view'
:
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--readonly'
,
'true'
]
...
...
@@ -111,7 +111,7 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
cmds
+=
[
'--log_level'
,
log_level
]
if
mode
in
[
'resume'
,
'view'
]:
cmds
+=
[
'--experiment_id'
,
experiment_id
]
if
args
.
foreground
:
if
foreground
:
cmds
+=
[
'--foreground'
,
'true'
]
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
...
...
@@ -122,12 +122,12 @@ def start_rest_server(args, platform, mode, config_file_name, experiment_id=None
stderr_file
.
write
(
log_header
)
if
sys
.
platform
==
'win32'
:
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
if
args
.
foreground
:
if
foreground
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
PIPE
,
stderr
=
STDOUT
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
if
args
.
foreground
:
if
foreground
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
PIPE
,
stderr
=
PIPE
)
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
)
...
...
@@ -428,12 +428,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground
=
False
if
mode
!=
'view'
:
foreground
=
args
.
foreground
if
log_level
not
in
[
'trace'
,
'debug'
]
and
(
args
.
debug
or
experiment_config
.
get
(
'debug'
)
is
True
):
log_level
=
'debug'
# start rest server
rest_process
,
start_time
=
start_rest_server
(
args
,
experiment_config
[
'trainingServicePlatform'
],
\
mode
,
config_file_name
,
experiment_id
,
log_dir
,
log_level
)
rest_process
,
start_time
=
start_rest_server
(
args
.
port
,
experiment_config
[
'trainingServicePlatform'
],
\
mode
,
config_file_name
,
foreground
,
experiment_id
,
log_dir
,
log_level
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
# Deal with annotation
if
experiment_config
.
get
(
'useAnnotation'
):
...
...
@@ -501,7 +503,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
experiment_config
[
'experimentName'
])
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
if
args
.
foreground
:
if
mode
!=
'view'
and
args
.
foreground
:
try
:
while
True
:
log_content
=
rest_process
.
stdout
.
readline
().
strip
().
decode
(
'utf-8'
)
...
...
tools/nni_cmd/nnictl.py
View file @
24fa4619
...
...
@@ -63,10 +63,10 @@ def parse_args():
parser_resume
.
set_defaults
(
func
=
resume_experiment
)
# parse view command
parser_
resume
=
subparsers
.
add_parser
(
'view'
,
help
=
'view a stopped experiment'
)
parser_
resume
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'The id of the experiment you want to view'
)
parser_
resume
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_
resume
.
set_defaults
(
func
=
view_experiment
)
parser_
view
=
subparsers
.
add_parser
(
'view'
,
help
=
'view a stopped experiment'
)
parser_
view
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'The id of the experiment you want to view'
)
parser_
view
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_
view
.
set_defaults
(
func
=
view_experiment
)
# parse update command
parser_updater
=
subparsers
.
add_parser
(
'update'
,
help
=
'update the experiment'
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment