Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
1500458a
Unverified
Commit
1500458a
authored
Jun 24, 2019
by
SparkSnail
Committed by
GitHub
Jun 24, 2019
Browse files
Merge pull request #187 from microsoft/master
merge master
parents
93dd76ba
97829ccd
Changes
57
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
375 additions
and
43 deletions
+375
-43
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+118
-5
src/webui/src/static/interface.ts
src/webui/src/static/interface.ts
+10
-5
src/webui/src/static/style/compare.scss
src/webui/src/static/style/compare.scss
+19
-0
src/webui/src/static/style/search.scss
src/webui/src/static/style/search.scss
+14
-1
src/webui/src/static/style/table.scss
src/webui/src/static/style/table.scss
+5
-0
test/async_sharing_test/simple_tuner.py
test/async_sharing_test/simple_tuner.py
+2
-2
test/config_test/multi_phase/multi_phase.test.yml
test/config_test/multi_phase/multi_phase.test.yml
+3
-3
test/config_test/multi_phase/multi_phase_batch.test.yml
test/config_test/multi_phase/multi_phase_batch.test.yml
+20
-0
test/config_test/multi_phase/multi_phase_evolution.test.yml
test/config_test/multi_phase/multi_phase_evolution.test.yml
+22
-0
test/config_test/multi_phase/multi_phase_grid.test.yml
test/config_test/multi_phase/multi_phase_grid.test.yml
+20
-0
test/config_test/multi_phase/multi_phase_metis.test.yml
test/config_test/multi_phase/multi_phase_metis.test.yml
+22
-0
test/config_test/multi_phase/multi_phase_tpe.test.yml
test/config_test/multi_phase/multi_phase_tpe.test.yml
+22
-0
test/config_test/multi_thread/multi_thread_tuner.py
test/config_test/multi_thread/multi_thread_tuner.py
+2
-2
test/naive_test/naive_tuner.py
test/naive_test/naive_tuner.py
+2
-2
tools/nni_trial_tool/constants.py
tools/nni_trial_tool/constants.py
+3
-1
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+85
-20
tools/nni_trial_tool/url_utils.py
tools/nni_trial_tool/url_utils.py
+6
-2
No files found.
src/webui/src/components/trial-detail/TableList.tsx
View file @
1500458a
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
axios
from
'
axios
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
Row
,
Table
,
Button
,
Popconfirm
,
Modal
,
Checkbox
}
from
'
antd
'
;
import
{
Row
,
Table
,
Button
,
Popconfirm
,
Modal
,
Checkbox
,
Select
}
from
'
antd
'
;
const
Option
=
Select
.
Option
;
const
CheckboxGroup
=
Checkbox
.
Group
;
const
CheckboxGroup
=
Checkbox
.
Group
;
import
{
MANAGER_IP
,
trialJobStatus
,
COLUMN
,
COLUMN_INDEX
}
from
'
../../static/const
'
;
import
{
MANAGER_IP
,
trialJobStatus
,
COLUMN
,
COLUMN_INDEX
}
from
'
../../static/const
'
;
import
{
convertDuration
,
intermediateGraphOption
,
killJob
}
from
'
../../static/function
'
;
import
{
convertDuration
,
intermediateGraphOption
,
killJob
}
from
'
../../static/function
'
;
import
{
TableObj
,
TrialJob
}
from
'
../../static/interface
'
;
import
{
TableObj
,
TrialJob
}
from
'
../../static/interface
'
;
import
OpenRow
from
'
../public-child/OpenRow
'
;
import
OpenRow
from
'
../public-child/OpenRow
'
;
import
Compare
from
'
../Modal/Compare
'
;
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
// table default metric column
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
// table default metric column
import
'
../../static/style/search.scss
'
;
import
'
../../static/style/search.scss
'
;
require
(
'
../../static/style/tableStatus.css
'
);
require
(
'
../../static/style/tableStatus.css
'
);
...
@@ -38,6 +40,12 @@ interface TableListState {
...
@@ -38,6 +40,12 @@ interface TableListState {
isObjFinal
:
boolean
;
isObjFinal
:
boolean
;
isShowColumn
:
boolean
;
isShowColumn
:
boolean
;
columnSelected
:
Array
<
string
>
;
// user select columnKeys
columnSelected
:
Array
<
string
>
;
// user select columnKeys
selectRows
:
Array
<
TableObj
>
;
isShowCompareModal
:
boolean
;
selectedRowKeys
:
string
[]
|
number
[];
intermediateData
:
Array
<
object
>
;
// a trial's intermediate results (include dict)
intermediateId
:
string
;
intermediateOtherKeys
:
Array
<
string
>
;
}
}
interface
ColumnIndex
{
interface
ColumnIndex
{
...
@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
public
_isMounted
=
false
;
public
_isMounted
=
false
;
public
intervalTrialLog
=
10
;
public
intervalTrialLog
=
10
;
public
_trialId
:
string
;
public
_trialId
:
string
;
public
tables
:
Table
<
TableObj
>
|
null
;
constructor
(
props
:
TableListProps
)
{
constructor
(
props
:
TableListProps
)
{
super
(
props
);
super
(
props
);
...
@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
modalVisible
:
false
,
modalVisible
:
false
,
isObjFinal
:
false
,
isObjFinal
:
false
,
isShowColumn
:
false
,
isShowColumn
:
false
,
columnSelected
:
COLUMN
isShowCompareModal
:
false
,
columnSelected
:
COLUMN
,
selectRows
:
[],
selectedRowKeys
:
[],
// close selected trial message after modal closed
intermediateData
:
[],
intermediateId
:
''
,
intermediateOtherKeys
:
[]
};
};
}
}
...
@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
.
then
(
res
=>
{
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
if
(
res
.
status
===
200
)
{
const
intermediateArr
:
number
[]
=
[];
const
intermediateArr
:
number
[]
=
[];
// support intermediate result is dict
// support intermediate result is dict because the last intermediate result is
// final result in a succeed trial, it may be a dict.
// get intermediate result dict keys array
let
otherkeys
:
Array
<
string
>
=
[
'
default
'
];
if
(
res
.
data
.
length
!==
0
)
{
otherkeys
=
Object
.
keys
(
JSON
.
parse
(
res
.
data
[
0
].
data
));
}
// intermediateArr just store default val
Object
.
keys
(
res
.
data
).
map
(
item
=>
{
Object
.
keys
(
res
.
data
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
res
.
data
[
item
].
data
);
const
temp
=
JSON
.
parse
(
res
.
data
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
if
(
typeof
temp
===
'
object
'
)
{
...
@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
id
);
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
id
);
if
(
this
.
_isMounted
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
this
.
setState
(()
=>
({
intermediateOption
:
intermediate
intermediateData
:
res
.
data
,
// store origin intermediate data for a trial
intermediateOption
:
intermediate
,
intermediateOtherKeys
:
otherkeys
,
intermediateId
:
id
}));
}));
}
}
}
}
...
@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
}
}
}
}
selectOtherKeys
=
(
value
:
string
)
=>
{
const
isShowDefault
:
boolean
=
value
===
'
default
'
?
true
:
false
;
const
{
intermediateData
,
intermediateId
}
=
this
.
state
;
const
intermediateArr
:
number
[]
=
[];
// just watch default key-val
if
(
isShowDefault
===
true
)
{
Object
.
keys
(
intermediateData
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
intermediateData
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
intermediateArr
.
push
(
temp
[
value
]);
}
else
{
intermediateArr
.
push
(
temp
);
}
});
}
else
{
Object
.
keys
(
intermediateData
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
intermediateData
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
intermediateArr
.
push
(
temp
[
value
]);
}
});
}
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
intermediateId
);
// re-render
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
intermediateOption
:
intermediate
}));
}
}
hideIntermediateModal
=
()
=>
{
hideIntermediateModal
=
()
=>
{
if
(
this
.
_isMounted
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
this
.
setState
({
...
@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
);
);
}
}
fillSelectedRowsTostate
=
(
selected
:
number
[]
|
string
[],
selectedRows
:
Array
<
TableObj
>
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
selectRows
:
selectedRows
,
selectedRowKeys
:
selected
}));
}
}
// open Compare-modal
compareBtn
=
()
=>
{
const
{
selectRows
}
=
this
.
state
;
if
(
selectRows
.
length
===
0
)
{
alert
(
'
Please select datas you want to compare!
'
);
}
else
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
({
isShowCompareModal
:
true
});
}
}
}
// close Compare-modal
hideCompareModal
=
()
=>
{
// close modal. clear select rows data, clear selected track
if
(
this
.
_isMounted
)
{
this
.
setState
({
isShowCompareModal
:
false
,
selectedRowKeys
:
[],
selectRows
:
[]
});
}
}
componentDidMount
()
{
componentDidMount
()
{
this
.
_isMounted
=
true
;
this
.
_isMounted
=
true
;
}
}
...
@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
render
()
{
render
()
{
const
{
entries
,
tableSource
,
updateList
}
=
this
.
props
;
const
{
entries
,
tableSource
,
updateList
}
=
this
.
props
;
const
{
intermediateOption
,
modalVisible
,
isShowColumn
,
columnSelected
}
=
this
.
state
;
const
{
intermediateOption
,
modalVisible
,
isShowColumn
,
columnSelected
,
selectRows
,
isShowCompareModal
,
selectedRowKeys
,
intermediateOtherKeys
}
=
this
.
state
;
const
rowSelection
=
{
selectedRowKeys
:
selectedRowKeys
,
onChange
:
(
selected
:
string
[]
|
number
[],
selectedRows
:
Array
<
TableObj
>
)
=>
{
this
.
fillSelectedRowsTostate
(
selected
,
selectedRows
);
}
};
let
showTitle
=
COLUMN
;
let
showTitle
=
COLUMN
;
let
bgColor
=
''
;
let
bgColor
=
''
;
const
trialJob
:
Array
<
TrialJob
>
=
[];
const
trialJob
:
Array
<
TrialJob
>
=
[];
...
@@ -417,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -417,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
<
Row
className
=
"tableList"
>
<
Row
className
=
"tableList"
>
<
div
id
=
"tableList"
>
<
div
id
=
"tableList"
>
<
Table
<
Table
ref
=
{
(
table
:
Table
<
TableObj
>
|
null
)
=>
this
.
tables
=
table
}
columns
=
{
showColumn
}
columns
=
{
showColumn
}
rowSelection
=
{
rowSelection
}
expandedRowRender
=
{
this
.
openRow
}
expandedRowRender
=
{
this
.
openRow
}
dataSource
=
{
tableSource
}
dataSource
=
{
tableSource
}
className
=
"commonTableStyle"
className
=
"commonTableStyle"
...
@@ -432,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -432,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
destroyOnClose
=
{
true
}
destroyOnClose
=
{
true
}
width
=
"80%"
width
=
"80%"
>
>
{
intermediateOtherKeys
.
length
>
1
?
<
Row
className
=
"selectKeys"
>
<
Select
className
=
"select"
defaultValue
=
"default"
onSelect
=
{
this
.
selectOtherKeys
}
>
{
Object
.
keys
(
intermediateOtherKeys
).
map
(
item
=>
{
const
keys
=
intermediateOtherKeys
[
item
];
return
<
Option
value
=
{
keys
}
key
=
{
item
}
>
{
keys
}
</
Option
>;
})
}
</
Select
>
</
Row
>
:
<
div
/>
}
<
ReactEcharts
<
ReactEcharts
option
=
{
intermediateOption
}
option
=
{
intermediateOption
}
style
=
{
{
style
=
{
{
...
@@ -458,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -458,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
className
=
"titleColumn"
className
=
"titleColumn"
/>
/>
</
Modal
>
</
Modal
>
<
Compare
compareRows
=
{
selectRows
}
visible
=
{
isShowCompareModal
}
cancelFunc
=
{
this
.
hideCompareModal
}
/>
</
Row
>
</
Row
>
);
);
}
}
...
...
src/webui/src/static/interface.ts
View file @
1500458a
...
@@ -108,10 +108,15 @@ interface FinalResult {
...
@@ -108,10 +108,15 @@ interface FinalResult {
data
:
string
;
data
:
string
;
}
}
interface
Intermedia
{
name
:
string
;
// id
type
:
string
;
data
:
Array
<
number
|
object
>
;
// intermediate data
hyperPara
:
object
;
// each trial hyperpara value
}
export
{
export
{
TableObj
,
Parameters
,
Experiment
,
TableObj
,
Parameters
,
Experiment
,
AccurPoint
,
TrialNumber
,
TrialJob
,
AccurPoint
,
TrialNumber
,
TrialJob
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalResult
,
FinalType
,
DetailAccurPoint
,
TooltipForAccuracy
,
TooltipForIntermediate
,
SearchSpace
,
Intermedia
ParaObj
,
Dimobj
,
FinalResult
,
FinalType
,
TooltipForIntermediate
,
SearchSpace
};
};
src/webui/src/static/style/compare.scss
0 → 100644
View file @
1500458a
.compare
{
width
:
92%
;
margin
:
0
auto
;
color
:
#333
;
tr
{
line-height
:
30px
;
border-bottom
:
1px
solid
#ccc
;
}
.column
{
width
:
124px
;
padding-left
:
18px
;
font-weight
:
700
;
}
.value
{
width
:
152px
;
padding-right
:
18px
;
text-align
:
right
;
}
}
src/webui/src/static/style/search.scss
View file @
1500458a
/* some buttons
about
trial-detail table */
/* some buttons
in
trial-detail table */
.allList
{
.allList
{
width
:
96%
;
width
:
96%
;
margin
:
0
auto
;
margin
:
0
auto
;
...
@@ -31,4 +31,17 @@
...
@@ -31,4 +31,17 @@
}
}
}
}
Button
.mediateBtn
{
margin
:
0
32px
;
}
/* each row's Intermediate btn -> Modal*/
.selectKeys
{
/* intermediate result is dict, select box for keys */
.select
{
width
:
120px
;
float
:
right
;
margin-right
:
10%
;
}
}
src/webui/src/static/style/table.scss
View file @
1500458a
...
@@ -102,3 +102,8 @@
...
@@ -102,3 +102,8 @@
.ant-modal-title
{
.ant-modal-title
{
font-size
:
20px
;
font-size
:
20px
;
}
}
/*disable select all checkbox in detail page*/
.ant-table-selection
{
display
:
none
;
}
test/async_sharing_test/simple_tuner.py
View file @
1500458a
...
@@ -22,7 +22,7 @@ class SimpleTuner(Tuner):
...
@@ -22,7 +22,7 @@ class SimpleTuner(Tuner):
self
.
sig_event
=
Event
()
self
.
sig_event
=
Event
()
self
.
thread_lock
=
Lock
()
self
.
thread_lock
=
Lock
()
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
if
self
.
f_id
is
None
:
if
self
.
f_id
is
None
:
self
.
thread_lock
.
acquire
()
self
.
thread_lock
.
acquire
()
self
.
f_id
=
parameter_id
self
.
f_id
=
parameter_id
...
@@ -50,7 +50,7 @@ class SimpleTuner(Tuner):
...
@@ -50,7 +50,7 @@ class SimpleTuner(Tuner):
self
.
thread_lock
.
release
()
self
.
thread_lock
.
release
()
return
self
.
trial_meta
[
parameter_id
]
return
self
.
trial_meta
[
parameter_id
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
reward
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
reward
,
**
kwargs
):
self
.
thread_lock
.
acquire
()
self
.
thread_lock
.
acquire
()
if
parameter_id
==
self
.
f_id
:
if
parameter_id
==
self
.
f_id
:
self
.
trial_meta
[
parameter_id
][
'checksum'
]
=
reward
[
'checksum'
]
self
.
trial_meta
[
parameter_id
][
'checksum'
]
=
reward
[
'checksum'
]
...
...
test/config_test/multi_phase/multi_phase.test.yml
View file @
1500458a
...
@@ -6,9 +6,9 @@ trialConcurrency: 4
...
@@ -6,9 +6,9 @@ trialConcurrency: 4
searchSpacePath
:
./search_space.json
searchSpacePath
:
./search_space.json
tuner
:
tuner
:
codeDir
:
../../../src/sdk/pynni/tests
builtinTunerName
:
TPE
class
FileName
:
test_multi_phase_tuner.py
class
Args
:
className
:
NaiveMultiPhaseTuner
optimize_mode
:
maximize
trial
:
trial
:
codeDir
:
.
codeDir
:
.
...
...
test/config_test/multi_phase/multi_phase_batch.test.yml
0 → 100644
View file @
1500458a
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
5m
maxTrialNum
:
8
trialConcurrency
:
4
searchSpacePath
:
./search_space.json
tuner
:
builtinTunerName
:
BatchTuner
trial
:
codeDir
:
.
command
:
python3 multi_phase.py
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
true
multiThread
:
false
trainingServicePlatform
:
local
test/config_test/multi_phase/multi_phase_evolution.test.yml
0 → 100644
View file @
1500458a
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
5m
maxTrialNum
:
8
trialConcurrency
:
4
searchSpacePath
:
./search_space.json
tuner
:
builtinTunerName
:
Evolution
classArgs
:
optimize_mode
:
maximize
trial
:
codeDir
:
.
command
:
python3 multi_phase.py
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
true
multiThread
:
false
trainingServicePlatform
:
local
test/config_test/multi_phase/multi_phase_grid.test.yml
0 → 100644
View file @
1500458a
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
5m
maxTrialNum
:
8
trialConcurrency
:
4
searchSpacePath
:
./search_space.json
tuner
:
builtinTunerName
:
GridSearch
trial
:
codeDir
:
.
command
:
python3 multi_phase.py
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
true
multiThread
:
false
trainingServicePlatform
:
local
test/config_test/multi_phase/multi_phase_metis.test.yml
0 → 100644
View file @
1500458a
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
5m
maxTrialNum
:
8
trialConcurrency
:
4
searchSpacePath
:
./search_space.json
tuner
:
builtinTunerName
:
MetisTuner
classArgs
:
optimize_mode
:
maximize
trial
:
codeDir
:
.
command
:
python3 multi_phase.py
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
true
multiThread
:
false
trainingServicePlatform
:
local
test/config_test/multi_phase/multi_phase_tpe.test.yml
0 → 100644
View file @
1500458a
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
5m
maxTrialNum
:
8
trialConcurrency
:
4
searchSpacePath
:
./search_space.json
tuner
:
builtinTunerName
:
TPE
classArgs
:
optimize_mode
:
maximize
trial
:
codeDir
:
.
command
:
python3 multi_phase.py
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
true
multiThread
:
false
trainingServicePlatform
:
local
test/config_test/multi_thread/multi_thread_tuner.py
View file @
1500458a
...
@@ -6,7 +6,7 @@ class MultiThreadTuner(Tuner):
...
@@ -6,7 +6,7 @@ class MultiThreadTuner(Tuner):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
parent_done
=
False
self
.
parent_done
=
False
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
if
parameter_id
==
0
:
if
parameter_id
==
0
:
return
{
'x'
:
0
}
return
{
'x'
:
0
}
else
:
else
:
...
@@ -14,7 +14,7 @@ class MultiThreadTuner(Tuner):
...
@@ -14,7 +14,7 @@ class MultiThreadTuner(Tuner):
time
.
sleep
(
2
)
time
.
sleep
(
2
)
return
{
'x'
:
1
}
return
{
'x'
:
1
}
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
if
parameter_id
==
0
:
if
parameter_id
==
0
:
self
.
parent_done
=
True
self
.
parent_done
=
True
...
...
test/naive_test/naive_tuner.py
View file @
1500458a
...
@@ -16,12 +16,12 @@ class NaiveTuner(Tuner):
...
@@ -16,12 +16,12 @@ class NaiveTuner(Tuner):
self
.
cur
=
0
self
.
cur
=
0
_logger
.
info
(
'init'
)
_logger
.
info
(
'init'
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
self
.
cur
+=
1
self
.
cur
+=
1
_logger
.
info
(
'generate parameters: %s'
%
self
.
cur
)
_logger
.
info
(
'generate parameters: %s'
%
self
.
cur
)
return
{
'x'
:
self
.
cur
}
return
{
'x'
:
self
.
cur
}
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
reward
=
extract_scalar_reward
(
value
)
reward
=
extract_scalar_reward
(
value
)
_logger
.
info
(
'receive trial result: %s, %s, %s'
%
(
parameter_id
,
parameters
,
reward
))
_logger
.
info
(
'receive trial result: %s, %s, %s'
%
(
parameter_id
,
parameters
,
reward
))
_result
.
write
(
'%d %d
\n
'
%
(
parameters
[
'x'
],
reward
))
_result
.
write
(
'%d %d
\n
'
%
(
parameters
[
'x'
],
reward
))
...
...
tools/nni_trial_tool/constants.py
View file @
1500458a
...
@@ -36,6 +36,8 @@ STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
...
@@ -36,6 +36,8 @@ STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
STDOUT_API
=
'/stdout'
STDOUT_API
=
'/stdout'
VERSION_API
=
'/version'
VERSION_API
=
'/version'
PARAMETER_META_API
=
'/parameter-file-meta'
NNI_SYS_DIR
=
os
.
environ
[
'NNI_SYS_DIR'
]
NNI_SYS_DIR
=
os
.
environ
[
'NNI_SYS_DIR'
]
NNI_TRIAL_JOB_ID
=
os
.
environ
[
'NNI_TRIAL_JOB_ID'
]
NNI_TRIAL_JOB_ID
=
os
.
environ
[
'NNI_TRIAL_JOB_ID'
]
NNI_EXP_ID
=
os
.
environ
[
'NNI_EXP_ID'
]
NNI_EXP_ID
=
os
.
environ
[
'NNI_EXP_ID'
]
\ No newline at end of file
MULTI_PHASE
=
os
.
environ
[
'MULTI_PHASE'
]
tools/nni_trial_tool/trial_keeper.py
View file @
1500458a
...
@@ -28,30 +28,27 @@ import re
...
@@ -28,30 +28,27 @@ import re
import
sys
import
sys
import
select
import
select
import
json
import
json
import
threading
from
pyhdfs
import
HdfsClient
from
pyhdfs
import
HdfsClient
import
pkg_resources
import
pkg_resources
from
.rest_utils
import
rest_post
from
.rest_utils
import
rest_post
,
rest_get
from
.url_utils
import
gen_send_stdout_url
,
gen_send_version_url
from
.url_utils
import
gen_send_stdout_url
,
gen_send_version_url
,
gen_parameter_meta_url
from
.constants
import
HOME_DIR
,
LOG_DIR
,
NNI_PLATFORM
,
STDOUT_FULL_PATH
,
STDERR_FULL_PATH
from
.constants
import
HOME_DIR
,
LOG_DIR
,
NNI_PLATFORM
,
STDOUT_FULL_PATH
,
STDERR_FULL_PATH
,
\
from
.hdfsClientUtility
import
copyDirectoryToHdfs
,
copyHdfsDirectoryToLocal
MULTI_PHASE
,
NNI_TRIAL_JOB_ID
,
NNI_SYS_DIR
,
NNI_EXP_ID
from
.hdfsClientUtility
import
copyDirectoryToHdfs
,
copyHdfsDirectoryToLocal
,
copyHdfsFileToLocal
from
.log_utils
import
LogType
,
nni_log
,
RemoteLogger
,
PipeLogReader
,
StdOutputType
from
.log_utils
import
LogType
,
nni_log
,
RemoteLogger
,
PipeLogReader
,
StdOutputType
logger
=
logging
.
getLogger
(
'trial_keeper'
)
logger
=
logging
.
getLogger
(
'trial_keeper'
)
regular
=
re
.
compile
(
'v?(?P<version>[0-9](\.[0-9]){0,1}).*'
)
regular
=
re
.
compile
(
'v?(?P<version>[0-9](\.[0-9]){0,1}).*'
)
def
main_loop
(
args
):
_hdfs_client
=
None
'''main loop logic for trial keeper'''
if
not
os
.
path
.
exists
(
LOG_DIR
):
def
get_hdfs_client
(
args
):
os
.
makedirs
(
LOG_DIR
)
global
_hdfs_client
stdout_file
=
open
(
STDOUT_FULL_PATH
,
'a+'
)
if
_hdfs_client
is
not
None
:
stderr_file
=
open
(
STDERR_FULL_PATH
,
'a+'
)
return
_hdfs_client
trial_keeper_syslogger
=
RemoteLogger
(
args
.
nnimanager_ip
,
args
.
nnimanager_port
,
'trial_keeper'
,
StdOutputType
.
Stdout
,
args
.
log_collection
)
# redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout
=
RemoteLogger
(
args
.
nnimanager_ip
,
args
.
nnimanager_port
,
'trial'
,
StdOutputType
.
Stdout
,
args
.
log_collection
)
sys
.
stdout
=
sys
.
stderr
=
trial_keeper_syslogger
# backward compatibility
# backward compatibility
hdfs_host
=
None
hdfs_host
=
None
hdfs_output_dir
=
None
hdfs_output_dir
=
None
...
@@ -59,21 +56,41 @@ def main_loop(args):
...
@@ -59,21 +56,41 @@ def main_loop(args):
hdfs_host
=
args
.
hdfs_host
hdfs_host
=
args
.
hdfs_host
elif
args
.
pai_hdfs_host
:
elif
args
.
pai_hdfs_host
:
hdfs_host
=
args
.
pai_hdfs_host
hdfs_host
=
args
.
pai_hdfs_host
if
args
.
hdfs_output_dir
:
else
:
hdfs_output_dir
=
args
.
hdfs_output_dir
return
None
elif
args
.
pai_hdfs_output_dir
:
hdfs_output_dir
=
args
.
pai_hdfs_output_dir
if
hdfs_host
is
not
None
and
args
.
nni_hdfs_exp_dir
is
not
None
:
if
hdfs_host
is
not
None
and
args
.
nni_hdfs_exp_dir
is
not
None
:
try
:
try
:
if
args
.
webhdfs_path
:
if
args
.
webhdfs_path
:
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
hdfs_host
),
user_name
=
args
.
pai_user_name
,
webhdfs_path
=
args
.
webhdfs_path
,
timeout
=
5
)
_
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
hdfs_host
),
user_name
=
args
.
pai_user_name
,
webhdfs_path
=
args
.
webhdfs_path
,
timeout
=
5
)
else
:
else
:
# backward compatibility
# backward compatibility
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:{1}'
.
format
(
hdfs_host
,
'50070'
),
user_name
=
args
.
pai_user_name
,
timeout
=
5
)
_
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:{1}'
.
format
(
hdfs_host
,
'50070'
),
user_name
=
args
.
pai_user_name
,
timeout
=
5
)
except
Exception
as
e
:
except
Exception
as
e
:
nni_log
(
LogType
.
Error
,
'Create HDFS client error: '
+
str
(
e
))
nni_log
(
LogType
.
Error
,
'Create HDFS client error: '
+
str
(
e
))
raise
e
raise
e
return
_hdfs_client
def
main_loop
(
args
):
'''main loop logic for trial keeper'''
if
not
os
.
path
.
exists
(
LOG_DIR
):
os
.
makedirs
(
LOG_DIR
)
stdout_file
=
open
(
STDOUT_FULL_PATH
,
'a+'
)
stderr_file
=
open
(
STDERR_FULL_PATH
,
'a+'
)
trial_keeper_syslogger
=
RemoteLogger
(
args
.
nnimanager_ip
,
args
.
nnimanager_port
,
'trial_keeper'
,
StdOutputType
.
Stdout
,
args
.
log_collection
)
# redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout
=
RemoteLogger
(
args
.
nnimanager_ip
,
args
.
nnimanager_port
,
'trial'
,
StdOutputType
.
Stdout
,
args
.
log_collection
)
sys
.
stdout
=
sys
.
stderr
=
trial_keeper_syslogger
if
args
.
hdfs_output_dir
:
hdfs_output_dir
=
args
.
hdfs_output_dir
elif
args
.
pai_hdfs_output_dir
:
hdfs_output_dir
=
args
.
pai_hdfs_output_dir
hdfs_client
=
get_hdfs_client
(
args
)
if
hdfs_client
is
not
None
:
copyHdfsDirectoryToLocal
(
args
.
nni_hdfs_exp_dir
,
os
.
getcwd
(),
hdfs_client
)
copyHdfsDirectoryToLocal
(
args
.
nni_hdfs_exp_dir
,
os
.
getcwd
(),
hdfs_client
)
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
...
@@ -138,6 +155,52 @@ def check_version(args):
...
@@ -138,6 +155,52 @@ def check_version(args):
except
AttributeError
as
err
:
except
AttributeError
as
err
:
nni_log
(
LogType
.
Error
,
err
)
nni_log
(
LogType
.
Error
,
err
)
def
is_multi_phase
():
return
MULTI_PHASE
and
(
MULTI_PHASE
in
[
'True'
,
'true'
])
def
download_parameter
(
meta_list
,
args
):
"""
Download parameter file to local working directory.
meta_list format is defined in paiJobRestServer.ts
example meta_list:
[
{"experimentId":"yWFJarYa","trialId":"UpPkl","filePath":"/chec/nni/experiments/yWFJarYa/trials/UpPkl/parameter_1.cfg"},
{"experimentId":"yWFJarYa","trialId":"aIUMA","filePath":"/chec/nni/experiments/yWFJarYa/trials/aIUMA/parameter_1.cfg"}
]
"""
nni_log
(
LogType
.
Debug
,
str
(
meta_list
))
nni_log
(
LogType
.
Debug
,
'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'
.
format
(
NNI_SYS_DIR
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
))
nni_log
(
LogType
.
Debug
,
'NNI_SYS_DIR files: {}'
.
format
(
os
.
listdir
(
NNI_SYS_DIR
)))
for
meta
in
meta_list
:
if
meta
[
'experimentId'
]
==
NNI_EXP_ID
and
meta
[
'trialId'
]
==
NNI_TRIAL_JOB_ID
:
param_fp
=
os
.
path
.
join
(
NNI_SYS_DIR
,
os
.
path
.
basename
(
meta
[
'filePath'
]))
if
not
os
.
path
.
exists
(
param_fp
):
hdfs_client
=
get_hdfs_client
(
args
)
copyHdfsFileToLocal
(
meta
[
'filePath'
],
param_fp
,
hdfs_client
,
override
=
False
)
def
fetch_parameter_file
(
args
):
class
FetchThread
(
threading
.
Thread
):
def
__init__
(
self
,
args
):
super
(
FetchThread
,
self
).
__init__
()
self
.
args
=
args
def
run
(
self
):
uri
=
gen_parameter_meta_url
(
self
.
args
.
nnimanager_ip
,
self
.
args
.
nnimanager_port
)
nni_log
(
LogType
.
Info
,
uri
)
while
True
:
res
=
rest_get
(
uri
,
10
)
nni_log
(
LogType
.
Debug
,
'status code: {}'
.
format
(
res
.
status_code
))
if
res
.
status_code
==
200
:
meta_list
=
res
.
json
()
download_parameter
(
meta_list
,
self
.
args
)
else
:
nni_log
(
LogType
.
Warning
,
'rest response: {}'
.
format
(
str
(
res
)))
time
.
sleep
(
5
)
fetch_file_thread
=
FetchThread
(
args
)
fetch_file_thread
.
start
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
'''NNI Trial Keeper main function'''
'''NNI Trial Keeper main function'''
PARSER
=
argparse
.
ArgumentParser
()
PARSER
=
argparse
.
ArgumentParser
()
...
@@ -159,6 +222,8 @@ if __name__ == '__main__':
...
@@ -159,6 +222,8 @@ if __name__ == '__main__':
exit
(
1
)
exit
(
1
)
check_version
(
args
)
check_version
(
args
)
try
:
try
:
if
is_multi_phase
():
fetch_parameter_file
(
args
)
main_loop
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
except
SystemExit
as
se
:
nni_log
(
LogType
.
Info
,
'NNI trial keeper exit with code {}'
.
format
(
se
.
code
))
nni_log
(
LogType
.
Info
,
'NNI trial keeper exit with code {}'
.
format
(
se
.
code
))
...
...
tools/nni_trial_tool/url_utils.py
View file @
1500458a
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from
.constants
import
API_ROOT_URL
,
BASE_URL
,
STDOUT_API
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
,
VERSION_API
from
.constants
import
API_ROOT_URL
,
BASE_URL
,
STDOUT_API
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
,
VERSION_API
,
PARAMETER_META_API
def
gen_send_stdout_url
(
ip
,
port
):
def
gen_send_stdout_url
(
ip
,
port
):
'''Generate send stdout url'''
'''Generate send stdout url'''
...
@@ -26,4 +26,8 @@ def gen_send_stdout_url(ip, port):
...
@@ -26,4 +26,8 @@ def gen_send_stdout_url(ip, port):
def
gen_send_version_url
(
ip
,
port
):
def
gen_send_version_url
(
ip
,
port
):
'''Generate send error url'''
'''Generate send error url'''
return
'{0}:{1}{2}{3}/{4}/{5}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
VERSION_API
,
NNI_EXP_ID
,
NNI_TRIAL_JOB_ID
)
return
'{0}:{1}{2}{3}/{4}/{5}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
VERSION_API
,
NNI_EXP_ID
,
NNI_TRIAL_JOB_ID
)
\ No newline at end of file
def
gen_parameter_meta_url
(
ip
,
port
):
'''Generate send error url'''
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
PARAMETER_META_API
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment