Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
05913424
Commit
05913424
authored
Aug 05, 2019
by
suiguoxin
Browse files
Merge branch 'master' into quniform-tuners
parents
e3c8552f
1dab3118
Changes
86
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
794 additions
and
511 deletions
+794
-511
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+63
-4
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+81
-22
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
+168
-57
src/webui/src/components/trial-detail/Intermeidate.tsx
src/webui/src/components/trial-detail/Intermeidate.tsx
+4
-4
src/webui/src/components/trial-detail/Para.tsx
src/webui/src/components/trial-detail/Para.tsx
+15
-11
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+1
-3
src/webui/src/static/interface.ts
src/webui/src/static/interface.ts
+7
-2
src/webui/src/static/style/para.scss
src/webui/src/static/style/para.scss
+4
-0
src/webui/src/static/style/search.scss
src/webui/src/static/style/search.scss
+18
-0
src/webui/src/static/style/table.scss
src/webui/src/static/style/table.scss
+6
-2
src/webui/src/static/style/trialsDetail.scss
src/webui/src/static/style/trialsDetail.scss
+13
-0
src/webui/yarn.lock
src/webui/yarn.lock
+248
-403
test/cli_test.py
test/cli_test.py
+56
-0
test/naive_test.py
test/naive_test.py
+7
-0
test/pipelines-it-local-windows.yml
test/pipelines-it-local-windows.yml
+4
-0
test/pipelines-it-local.yml
test/pipelines-it-local.yml
+4
-0
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+2
-2
tools/nni_annotation/test_annotation.py
tools/nni_annotation/test_annotation.py
+2
-1
tools/nni_annotation/testcase/annotated/nas.py
tools/nni_annotation/testcase/annotated/nas.py
+49
-0
tools/nni_annotation/testcase/searchspace.json
tools/nni_annotation/testcase/searchspace.json
+42
-0
No files found.
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
05913424
...
...
@@ -193,13 +193,19 @@ class HyperoptTuner(Tuner):
HyperoptTuner is a tuner which using hyperopt algorithm.
"""
def
__init__
(
self
,
algorithm_name
,
optimize_mode
=
'minimize'
):
def
__init__
(
self
,
algorithm_name
,
optimize_mode
=
'minimize'
,
parallel_optimize
=
False
,
constant_liar_type
=
'min'
):
"""
Parameters
----------
algorithm_name : str
algorithm_name includes "tpe", "random_search" and anneal".
optimize_mode : str
parallel_optimize : bool
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
constant_liar_type : str
constant_liar_type including "min", "max" and "mean"
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
"""
self
.
algorithm_name
=
algorithm_name
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
...
...
@@ -208,6 +214,13 @@ class HyperoptTuner(Tuner):
self
.
rval
=
None
self
.
supplement_data_num
=
0
self
.
parallel
=
parallel_optimize
if
self
.
parallel
:
self
.
CL_rval
=
None
self
.
constant_liar_type
=
constant_liar_type
self
.
running_data
=
[]
self
.
optimal_y
=
None
def
_choose_tuner
(
self
,
algorithm_name
):
"""
Parameters
...
...
@@ -269,6 +282,10 @@ class HyperoptTuner(Tuner):
# but it can cause deplicate parameter rarely
total_params
=
self
.
get_suggestion
(
random_search
=
True
)
self
.
total_data
[
parameter_id
]
=
total_params
if
self
.
parallel
:
self
.
running_data
.
append
(
parameter_id
)
params
=
split_index
(
total_params
)
return
params
...
...
@@ -290,10 +307,39 @@ class HyperoptTuner(Tuner):
raise
RuntimeError
(
'Received parameter_id not in total_data.'
)
params
=
self
.
total_data
[
parameter_id
]
# code for parallel
if
self
.
parallel
:
constant_liar
=
kwargs
.
get
(
'constant_liar'
,
False
)
if
constant_liar
:
rval
=
self
.
CL_rval
else
:
rval
=
self
.
rval
self
.
running_data
.
remove
(
parameter_id
)
# update the reward of optimal_y
if
self
.
optimal_y
is
None
:
if
self
.
constant_liar_type
==
'mean'
:
self
.
optimal_y
=
[
reward
,
1
]
else
:
self
.
optimal_y
=
reward
else
:
if
self
.
constant_liar_type
==
'mean'
:
_sum
=
self
.
optimal_y
[
0
]
+
reward
_number
=
self
.
optimal_y
[
1
]
+
1
self
.
optimal_y
=
[
_sum
,
_number
]
elif
self
.
constant_liar_type
==
'min'
:
self
.
optimal_y
=
min
(
self
.
optimal_y
,
reward
)
elif
self
.
constant_liar_type
==
'max'
:
self
.
optimal_y
=
max
(
self
.
optimal_y
,
reward
)
logger
.
debug
(
"Update optimal_y with reward, optimal_y = %s"
,
self
.
optimal_y
)
else
:
rval
=
self
.
rval
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
reward
=
-
reward
rval
=
self
.
rval
domain
=
rval
.
domain
trials
=
rval
.
trials
...
...
@@ -378,13 +424,26 @@ class HyperoptTuner(Tuner):
total_params : dict
parameter suggestion
"""
if
self
.
parallel
and
len
(
self
.
total_data
)
>
20
and
len
(
self
.
running_data
)
and
self
.
optimal_y
is
not
None
:
self
.
CL_rval
=
copy
.
deepcopy
(
self
.
rval
)
if
self
.
constant_liar_type
==
'mean'
:
_constant_liar_y
=
self
.
optimal_y
[
0
]
/
self
.
optimal_y
[
1
]
else
:
_constant_liar_y
=
self
.
optimal_y
for
_parameter_id
in
self
.
running_data
:
self
.
receive_trial_result
(
parameter_id
=
_parameter_id
,
parameters
=
None
,
value
=
_constant_liar_y
,
constant_liar
=
True
)
rval
=
self
.
CL_rval
rval
=
self
.
rval
random_state
=
np
.
random
.
randint
(
2
**
31
-
1
)
else
:
rval
=
self
.
rval
random_state
=
rval
.
rstate
.
randint
(
2
**
31
-
1
)
trials
=
rval
.
trials
algorithm
=
rval
.
algo
new_ids
=
rval
.
trials
.
new_trial_ids
(
1
)
rval
.
trials
.
refresh
()
random_state
=
rval
.
rstate
.
randint
(
2
**
31
-
1
)
if
random_search
:
new_trials
=
hp
.
rand
.
suggest
(
new_ids
,
rval
.
domain
,
trials
,
random_state
)
...
...
src/webui/src/components/TrialsDetail.tsx
View file @
05913424
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
{
MANAGER_IP
}
from
'
../static/const
'
;
import
{
Row
,
Col
,
Tabs
,
Input
,
Select
,
Button
,
Icon
}
from
'
antd
'
;
import
{
Row
,
Col
,
Tabs
,
Select
,
Button
,
Icon
}
from
'
antd
'
;
const
Option
=
Select
.
Option
;
import
{
TableObj
,
Parameters
}
from
'
../static/interface
'
;
import
{
TableObj
,
Parameters
,
ExperimentInfo
}
from
'
../static/interface
'
;
import
{
getFinal
}
from
'
../static/function
'
;
import
DefaultPoint
from
'
./trial-detail/DefaultMetricPoint
'
;
import
Duration
from
'
./trial-detail/Duration
'
;
...
...
@@ -13,6 +13,7 @@ import Intermediate from './trial-detail/Intermeidate';
import
TableList
from
'
./trial-detail/TableList
'
;
const
TabPane
=
Tabs
.
TabPane
;
import
'
../static/style/trialsDetail.scss
'
;
import
'
../static/style/search.scss
'
;
interface
TrialDetailState
{
accSource
:
object
;
...
...
@@ -20,8 +21,6 @@ interface TrialDetailState {
tableListSource
:
Array
<
TableObj
>
;
searchResultSource
:
Array
<
TableObj
>
;
isHasSearch
:
boolean
;
experimentStatus
:
string
;
experimentPlatform
:
string
;
experimentLogCollection
:
boolean
;
entriesTable
:
number
;
// table components val
entriesInSelect
:
string
;
...
...
@@ -31,6 +30,9 @@ interface TrialDetailState {
hyperCounts
:
number
;
// user click the hyper-parameter counts
durationCounts
:
number
;
intermediateCounts
:
number
;
experimentInfo
:
ExperimentInfo
;
searchFilter
:
string
;
searchPlaceHolder
:
string
;
}
interface
TrialsDetailProps
{
...
...
@@ -46,6 +48,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
public
interAllTableList
=
2
;
public
tableList
:
TableList
|
null
;
public
searchInput
:
HTMLInputElement
|
null
;
private
titleOfacc
=
(
<
Title1
text
=
"Default metric"
icon
=
"3.png"
/>
...
...
@@ -74,8 +77,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
accNodata
:
''
,
tableListSource
:
[],
searchResultSource
:
[],
experimentStatus
:
''
,
experimentPlatform
:
''
,
experimentLogCollection
:
false
,
entriesTable
:
20
,
entriesInSelect
:
'
20
'
,
...
...
@@ -85,7 +86,13 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
isMultiPhase
:
false
,
hyperCounts
:
0
,
durationCounts
:
0
,
intermediateCounts
:
0
intermediateCounts
:
0
,
experimentInfo
:
{
platform
:
''
,
optimizeMode
:
'
maximize
'
},
searchFilter
:
'
id
'
,
searchPlaceHolder
:
'
Search by id
'
};
}
...
...
@@ -212,15 +219,33 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}));
}
}
else
{
const
{
tableListSource
}
=
this
.
state
;
const
{
tableListSource
,
searchFilter
}
=
this
.
state
;
const
searchResultList
:
Array
<
TableObj
>
=
[];
Object
.
keys
(
tableListSource
).
map
(
key
=>
{
const
item
=
tableListSource
[
key
];
if
(
item
.
sequenceId
.
toString
()
===
targetValue
||
item
.
id
.
includes
(
targetValue
)
||
item
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
())
)
{
searchResultList
.
push
(
item
);
switch
(
searchFilter
)
{
case
'
id
'
:
if
(
item
.
id
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
searchResultList
.
push
(
item
);
}
break
;
case
'
Trial No.
'
:
if
(
item
.
sequenceId
.
toString
()
===
targetValue
)
{
searchResultList
.
push
(
item
);
}
break
;
case
'
status
'
:
if
(
item
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
searchResultList
.
push
(
item
);
}
break
;
case
'
parameters
'
:
const
strParameters
=
JSON
.
stringify
(
item
.
description
.
parameters
,
null
,
4
);
if
(
strParameters
.
includes
(
targetValue
))
{
searchResultList
.
push
(
item
);
}
break
;
default
:
}
});
if
(
this
.
_isMounted
)
{
...
...
@@ -282,6 +307,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
alert
(
'
TableList component was not properly initialized.
'
);
}
getSearchFilter
=
(
value
:
string
)
=>
{
// clear input value and re-render table
if
(
this
.
searchInput
!==
null
)
{
this
.
searchInput
.
value
=
''
;
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isHasSearch
:
false
}));
}
}
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
searchFilter
:
value
,
searchPlaceHolder
:
`Search by
${
value
}
`
}));
}
}
// get and set logCollection val
checkExperimentPlatform
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
...
...
@@ -289,7 +327,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
})
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
trainingPlatform
=
res
.
data
.
params
.
trainingServicePlatform
!==
undefined
const
trainingPlatform
:
string
=
res
.
data
.
params
.
trainingServicePlatform
!==
undefined
?
res
.
data
.
params
.
trainingServicePlatform
:
...
...
@@ -299,12 +337,24 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
let
expLogCollection
:
boolean
=
false
;
const
isMultiy
:
boolean
=
res
.
data
.
params
.
multiPhase
!==
undefined
?
res
.
data
.
params
.
multiPhase
:
false
;
const
tuner
=
res
.
data
.
params
.
tuner
;
// I'll set optimize is maximize if user not set optimize
let
optimize
:
string
=
'
maximize
'
;
if
(
tuner
!==
undefined
)
{
if
(
tuner
.
classArgs
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
===
'
minimize
'
)
{
optimize
=
'
minimize
'
;
}
}
}
}
if
(
logCollection
!==
undefined
&&
logCollection
!==
'
none
'
)
{
expLogCollection
=
true
;
}
if
(
this
.
_isMounted
)
{
this
.
setState
({
experiment
P
latform
:
trainingPlatform
,
experiment
Info
:
{
p
latform
:
trainingPlatform
,
optimizeMode
:
optimize
},
searchSpace
:
res
.
data
.
params
.
searchSpace
,
experimentLogCollection
:
expLogCollection
,
isMultiPhase
:
isMultiy
...
...
@@ -343,8 +393,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
const
{
tableListSource
,
searchResultSource
,
isHasSearch
,
isMultiPhase
,
entriesTable
,
experiment
Platform
,
searchSpace
,
experimentLogCollection
,
whichGraph
entriesTable
,
experiment
Info
,
searchSpace
,
experimentLogCollection
,
whichGraph
,
searchPlaceHolder
}
=
this
.
state
;
const
source
=
isHasSearch
?
searchResultSource
:
tableListSource
;
return
(
...
...
@@ -354,9 +404,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
TabPane
tab
=
{
this
.
titleOfacc
}
key
=
"1"
>
<
Row
className
=
"graph"
>
<
DefaultPoint
height
=
{
4
3
2
}
height
=
{
4
0
2
}
showSource
=
{
source
}
whichGraph
=
{
whichGraph
}
optimize
=
{
experimentInfo
.
optimizeMode
}
/>
</
Row
>
</
TabPane
>
...
...
@@ -408,11 +459,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Compare
</
Button
>
<
Input
<
Select
defaultValue
=
"id"
className
=
"filter"
onSelect
=
{
this
.
getSearchFilter
}
>
<
Option
value
=
"id"
>
Id
</
Option
>
<
Option
value
=
"Trial No."
>
Trial No.
</
Option
>
<
Option
value
=
"status"
>
Status
</
Option
>
<
Option
value
=
"parameters"
>
Parameters
</
Option
>
</
Select
>
<
input
type
=
"text"
placeholder
=
"Search by id, trial No. or status"
className
=
"search-input"
placeholder
=
{
searchPlaceHolder
}
onChange
=
{
this
.
searchTrial
}
style
=
{
{
width
:
230
,
marginLeft
:
6
}
}
style
=
{
{
width
:
230
}
}
ref
=
{
text
=>
(
this
.
searchInput
)
=
text
}
/>
</
Col
>
</
Row
>
...
...
@@ -420,7 +479,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
entries
=
{
entriesTable
}
tableSource
=
{
source
}
isMultiPhase
=
{
isMultiPhase
}
platform
=
{
experiment
P
latform
}
platform
=
{
experiment
Info
.
p
latform
}
updateList
=
{
this
.
getDetailSource
}
logCollection
=
{
experimentLogCollection
}
ref
=
{
(
tabList
)
=>
this
.
tableList
=
tabList
}
...
...
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
View file @
05913424
import
*
as
React
from
'
react
'
;
import
{
Switch
}
from
'
antd
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
filterByStatus
}
from
'
../../static/function
'
;
import
{
TableObj
,
DetailAccurPoint
,
TooltipForAccuracy
}
from
'
../../static/interface
'
;
...
...
@@ -10,32 +11,36 @@ interface DefaultPointProps {
showSource
:
Array
<
TableObj
>
;
height
:
number
;
whichGraph
:
string
;
optimize
:
string
;
}
interface
DefaultPointState
{
defaultSource
:
object
;
accNodata
:
string
;
succeedTrials
:
number
;
isViewBestCurve
:
boolean
;
}
class
DefaultPoint
extends
React
.
Component
<
DefaultPointProps
,
DefaultPointState
>
{
public
_isMounted
=
false
;
public
_is
Default
Mounted
=
false
;
constructor
(
props
:
DefaultPointProps
)
{
super
(
props
);
this
.
state
=
{
defaultSource
:
{},
accNodata
:
''
,
succeedTrials
:
10000000
succeedTrials
:
10000000
,
isViewBestCurve
:
false
};
}
defaultMetric
=
(
succeedSource
:
Array
<
TableObj
>
)
=>
{
defaultMetric
=
(
succeedSource
:
Array
<
TableObj
>
,
isCurve
:
boolean
)
=>
{
const
{
optimize
}
=
this
.
props
;
const
accSource
:
Array
<
DetailAccurPoint
>
=
[];
const
showSource
:
Array
<
TableObj
>
=
succeedSource
.
filter
(
filterByStatus
);
const
lengthOfSource
=
showSource
.
length
;
const
tooltipDefault
=
lengthOfSource
===
0
?
'
No data
'
:
''
;
if
(
this
.
_isMounted
===
true
)
{
if
(
this
.
_is
Default
Mounted
===
true
)
{
this
.
setState
(()
=>
({
succeedTrials
:
lengthOfSource
,
accNodata
:
tooltipDefault
...
...
@@ -55,94 +60,195 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
type
:
'
value
'
,
}
};
if
(
this
.
_isMounted
===
true
)
{
if
(
this
.
_is
Default
Mounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
nullGraph
}));
}
}
else
{
const
resultList
:
Array
<
number
|
string
>
[]
=
[];
const
resultList
:
Array
<
number
|
object
>
[]
=
[];
const
lineListDefault
:
Array
<
number
>
=
[];
Object
.
keys
(
showSource
).
map
(
item
=>
{
const
temp
=
showSource
[
item
];
if
(
temp
.
acc
!==
undefined
)
{
if
(
temp
.
acc
.
default
!==
undefined
)
{
const
searchSpace
=
temp
.
description
.
parameters
;
lineListDefault
.
push
(
temp
.
acc
.
default
);
accSource
.
push
({
acc
:
temp
.
acc
.
default
,
index
:
temp
.
sequenceId
,
searchSpace
:
JSON
.
stringify
(
searchSpace
)
searchSpace
:
searchSpace
});
}
}
});
// deal with best metric line
const
bestCurve
:
Array
<
number
|
object
>
[]
=
[];
// best curve data source
bestCurve
.
push
([
0
,
lineListDefault
[
0
],
accSource
[
0
].
searchSpace
]);
// push the first value
if
(
optimize
===
'
maximize
'
)
{
for
(
let
i
=
1
;
i
<
lineListDefault
.
length
;
i
++
)
{
const
val
=
lineListDefault
[
i
];
const
latest
=
bestCurve
[
bestCurve
.
length
-
1
][
1
];
if
(
val
>=
latest
)
{
bestCurve
.
push
([
i
,
val
,
accSource
[
i
].
searchSpace
]);
}
else
{
bestCurve
.
push
([
i
,
latest
,
accSource
[
i
].
searchSpace
]);
}
}
}
else
{
for
(
let
i
=
1
;
i
<
lineListDefault
.
length
;
i
++
)
{
const
val
=
lineListDefault
[
i
];
const
latest
=
bestCurve
[
bestCurve
.
length
-
1
][
1
];
if
(
val
<=
latest
)
{
bestCurve
.
push
([
i
,
val
,
accSource
[
i
].
searchSpace
]);
}
else
{
bestCurve
.
push
([
i
,
latest
,
accSource
[
i
].
searchSpace
]);
}
}
}
Object
.
keys
(
accSource
).
map
(
item
=>
{
const
items
=
accSource
[
item
];
let
temp
:
Array
<
number
|
string
>
;
temp
=
[
items
.
index
,
items
.
acc
,
JSON
.
parse
(
items
.
searchSpace
)
];
let
temp
:
Array
<
number
|
object
>
;
temp
=
[
items
.
index
,
items
.
acc
,
items
.
searchSpace
];
resultList
.
push
(
temp
);
});
// isViewBestCurve: false show default metric graph
// isViewBestCurve: true show best curve
if
(
isCurve
===
true
)
{
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
this
.
drawBestcurve
(
bestCurve
,
resultList
)
}));
}
}
else
{
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
this
.
drawDefaultMetric
(
resultList
)
}));
}
}
}
}
const
allAcuracy
=
{
grid
:
{
left
:
'
8%
'
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
{
if
(
data
.
data
[
0
]
<
resultList
.
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
formatter
:
function
(
data
:
TooltipForAccuracy
)
{
const
result
=
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Default metric:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
return
result
;
drawBestcurve
=
(
realDefault
:
Array
<
number
|
object
>
[],
resultList
:
Array
<
number
|
object
>
[])
=>
{
return
{
grid
:
{
left
:
'
8%
'
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
{
if
(
data
.
data
[
0
]
<
realDefault
.
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
formatter
:
function
(
data
:
TooltipForAccuracy
)
{
const
result
=
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Optimization curve:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
return
result
;
}
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
resultList
},
{
type
:
'
line
'
,
lineStyle
:
{
color
:
'
#FF6600
'
},
data
:
realDefault
}]
};
}
drawDefaultMetric
=
(
resultList
:
Array
<
number
|
object
>
[])
=>
{
return
{
grid
:
{
left
:
'
8%
'
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
{
if
(
data
.
data
[
0
]
<
resultList
.
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
resultList
}]
};
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
allAcuracy
}));
}
formatter
:
function
(
data
:
TooltipForAccuracy
)
{
const
result
=
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Default metric:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
return
result
;
}
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
resultList
}]
};
}
loadDefault
=
(
checked
:
boolean
)
=>
{
// checked: true show best metric curve
const
{
showSource
}
=
this
.
props
;
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
defaultMetric
(
showSource
,
checked
);
// ** deal with data and then update view layer
this
.
setState
(()
=>
({
isViewBestCurve
:
checked
}));
}
}
// update parent component state
componentWillReceiveProps
(
nextProps
:
DefaultPointProps
)
{
const
{
whichGraph
,
showSource
}
=
nextProps
;
const
{
isViewBestCurve
}
=
this
.
state
;
if
(
whichGraph
===
'
1
'
)
{
this
.
defaultMetric
(
showSource
);
this
.
defaultMetric
(
showSource
,
isViewBestCurve
);
}
}
shouldComponentUpdate
(
nextProps
:
DefaultPointProps
,
nextState
:
DefaultPointState
)
{
const
{
whichGraph
}
=
nextProps
;
const
succTrial
=
this
.
state
.
succeedTrials
;
const
{
succeedTrials
}
=
nextState
;
if
(
whichGraph
===
'
1
'
)
{
const
{
succeedTrials
,
isViewBestCurve
}
=
nextState
;
const
succTrial
=
this
.
state
.
succeedTrials
;
const
isViewBestCurveBefore
=
this
.
state
.
isViewBestCurve
;
if
(
isViewBestCurveBefore
!==
isViewBestCurve
)
{
return
true
;
}
if
(
succeedTrials
!==
succTrial
)
{
return
true
;
}
...
...
@@ -152,11 +258,11 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
this
.
_is
Default
Mounted
=
true
;
}
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
this
.
_is
Default
Mounted
=
false
;
}
render
()
{
...
...
@@ -164,6 +270,12 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
const
{
defaultSource
,
accNodata
}
=
this
.
state
;
return
(
<
div
>
<
div
className
=
"default-metric"
>
<
div
className
=
"position"
>
<
span
className
=
"bold"
>
optimization curve
</
span
>
<
Switch
defaultChecked
=
{
false
}
onChange
=
{
this
.
loadDefault
}
/>
</
div
>
</
div
>
<
ReactEcharts
option
=
{
defaultSource
}
style
=
{
{
...
...
@@ -173,7 +285,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
}
theme
=
"my_theme"
notMerge
=
{
true
}
// update now
// lazyUpdate={true}
/>
<
div
className
=
"showMess"
>
{
accNodata
}
</
div
>
</
div
>
...
...
src/webui/src/components/trial-detail/Intermeidate.tsx
View file @
05913424
...
...
@@ -114,7 +114,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
yAxis
:
{
type
:
'
value
'
,
name
:
'
m
etric
'
name
:
'
M
etric
'
},
series
:
trialIntermediate
};
...
...
@@ -136,7 +136,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
yAxis
:
{
type
:
'
value
'
,
name
:
'
m
etric
'
name
:
'
M
etric
'
}
};
if
(
this
.
_isMounted
)
{
...
...
@@ -283,9 +283,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
{
/* style in para.scss */
}
<
Row
className
=
"meline intermediate"
>
<
Col
span
=
{
8
}
/>
<
Col
span
=
{
3
}
style
=
{
{
height
:
34
}
}
>
<
Col
span
=
{
3
}
className
=
"inter-filter-btn"
>
{
/* filter message */
}
<
span
>
f
ilter
</
span
>
<
span
>
F
ilter
</
span
>
<
Switch
defaultChecked
=
{
false
}
onChange
=
{
this
.
switchTurn
}
...
...
src/webui/src/components/trial-detail/Para.tsx
View file @
05913424
...
...
@@ -87,13 +87,10 @@ class Para extends React.Component<ParaProps, ParaState> {
let
temp
:
Array
<
number
>
=
[];
for
(
let
i
=
0
;
i
<
dimName
.
length
;
i
++
)
{
if
(
'
type
'
in
parallelAxis
[
i
])
{
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]].
toString
()
);
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]].
toString
());
}
else
{
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]]
);
// default metric
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]]);
}
}
paraYdata
.
push
(
temp
);
...
...
@@ -199,11 +196,18 @@ class Para extends React.Component<ParaProps, ParaState> {
break
;
// support log distribute
case
'
loguniform
'
:
parallelAxis
.
push
({
dim
:
i
,
name
:
dimName
[
i
],
type
:
'
log
'
,
});
if
(
lenOfDataSource
>
1
)
{
parallelAxis
.
push
({
dim
:
i
,
name
:
dimName
[
i
],
type
:
'
log
'
,
});
}
else
{
parallelAxis
.
push
({
dim
:
i
,
name
:
dimName
[
i
]
});
}
break
;
default
:
...
...
src/webui/src/components/trial-detail/TableList.tsx
View file @
05913424
...
...
@@ -321,9 +321,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key
:
'
sequenceId
'
,
width
:
120
,
className
:
'
tableHead
'
,
sorter
:
(
a
:
TableObj
,
b
:
TableObj
)
=>
(
a
.
sequenceId
as
number
)
-
(
b
.
sequenceId
as
number
)
sorter
:
(
a
:
TableObj
,
b
:
TableObj
)
=>
(
a
.
sequenceId
as
number
)
-
(
b
.
sequenceId
as
number
)
});
break
;
case
'
ID
'
:
...
...
src/webui/src/static/interface.ts
View file @
05913424
...
...
@@ -59,7 +59,7 @@ interface AccurPoint {
interface
DetailAccurPoint
{
acc
:
number
;
index
:
number
;
searchSpace
:
string
;
searchSpace
:
object
;
}
interface
TooltipForIntermediate
{
...
...
@@ -117,8 +117,13 @@ interface Intermedia {
hyperPara
:
object
;
// each trial hyperpara value
}
interface
ExperimentInfo
{
platform
:
string
;
optimizeMode
:
string
;
}
export
{
TableObj
,
Parameters
,
Experiment
,
AccurPoint
,
TrialNumber
,
TrialJob
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalResult
,
FinalType
,
TooltipForIntermediate
,
SearchSpace
,
Intermedia
TooltipForIntermediate
,
SearchSpace
,
Intermedia
,
ExperimentInfo
};
src/webui/src/static/style/para.scss
View file @
05913424
...
...
@@ -36,6 +36,10 @@
.strange
{
margin-top
:
2px
;
}
.inter-filter-btn
{
height
:
34px
;
line-height
:
34px
;
}
.range
{
.heng
{
margin-left
:
6px
;
...
...
src/webui/src/static/style/search.scss
View file @
05913424
...
...
@@ -11,6 +11,24 @@
color
:
#0071BC
;
border-radius
:
0
;
}
.filter
{
width
:
100px
;
margin-left
:
8px
;
.ant-select-selection-selected-value
{
font-size
:
14px
;
}
}
.search-input
{
height
:
32px
;
outline
:
none
;
border
:
1px
solid
#d9d9d9
;
border-left
:
none
;
padding-left
:
8px
;
color
:
#333
;
}
.
search-input
:
:
placeholder
{
color
:
DarkGray
;
}
}
.entry
{
width
:
120px
;
...
...
src/webui/src/static/style/table.scss
View file @
05913424
...
...
@@ -31,14 +31,12 @@
text-align
:
center
;
color
:
#212121
;
font-size
:
14px
;
/* background-color: #f2f2f2; */
}
th
{
padding
:
2px
;
background-color
:white
!
important
;
font-size
:
14px
;
color
:
#808080
;
border-bottom
:
1px
solid
#d0d0d0
;
text-align
:
center
;
}
...
...
@@ -105,3 +103,9 @@
.ant-table-selection
{
display
:
none
;
}
/* fix the border-bottom bug in firefox and edge */
.
ant-table-thead
>
tr
>
th
.
ant-table-column-sorters
:
:
before
{
padding-bottom
:
25px
;
border-bottom
:
1px
solid
#e8e8e8
;
}
\ No newline at end of file
src/webui/src/static/style/trialsDetail.scss
View file @
05913424
...
...
@@ -70,3 +70,16 @@
.allList
{
margin-top
:
15px
;
}
.default-metric
{
width
:
90%
;
text-align
:
right
;
margin-top
:
15px
;
.position
{
color
:
#333
;
.bold
{
font-weight
:
600
;
margin-right
:
10px
;
}
}
}
src/webui/yarn.lock
View file @
05913424
This diff is collapsed.
Click to expand it.
test/cli_test.py
0 → 100644
View file @
05913424
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import
sys
import
time
import
traceback
from
utils
import
GREEN
,
RED
,
CLEAR
,
setup_experiment
def
test_nni_cli
():
import
nnicli
as
nc
config_file
=
'config_test/examples/mnist.test.yml'
try
:
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
time
.
sleep
(
6
)
print
(
GREEN
+
'Testing nnicli:'
+
config_file
+
CLEAR
)
nc
.
start_nni
(
config_file
)
time
.
sleep
(
3
)
nc
.
set_endpoint
(
'http://localhost:8080'
)
print
(
nc
.
version
())
print
(
nc
.
get_job_statistics
())
print
(
nc
.
get_experiment_status
())
nc
.
list_trial_jobs
()
print
(
GREEN
+
'Test nnicli {}: TEST PASS'
.
format
(
config_file
)
+
CLEAR
)
except
Exception
as
error
:
print
(
RED
+
'Test nnicli {}: TEST FAIL'
.
format
(
config_file
)
+
CLEAR
)
print
(
'%r'
%
error
)
traceback
.
print_exc
()
raise
error
finally
:
nc
.
stop_nni
()
if
__name__
==
'__main__'
:
installed
=
(
sys
.
argv
[
-
1
]
!=
'--preinstall'
)
setup_experiment
(
installed
)
test_nni_cli
()
test/naive_test.py
View file @
05913424
...
...
@@ -88,6 +88,7 @@ def stop_experiment_test():
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8080'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8888'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8989'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8990'
],
check
=
True
)
# test cmd 'nnictl stop id`
experiment_id
=
get_experiment_id
(
EXPERIMENT_URL
)
...
...
@@ -96,6 +97,12 @@ def stop_experiment_test():
snooze
()
assert
not
detect_port
(
8080
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
# test cmd `nnictl stop --port`
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'--port'
,
'8990'
])
assert
proc
.
returncode
==
0
,
'`nnictl stop %s` failed with code %d'
%
(
experiment_id
,
proc
.
returncode
)
snooze
()
assert
not
detect_port
(
8990
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
# test cmd `nnictl stop all`
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'all'
])
assert
proc
.
returncode
==
0
,
'`nnictl stop all` failed with code %d'
%
proc
.
returncode
...
...
test/pipelines-it-local-windows.yml
View file @
05913424
...
...
@@ -36,3 +36,7 @@ jobs:
cd test
python metrics_test.py
displayName
:
'
Trial
job
metrics
test'
-
script
:
|
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName
:
'
nnicli
test'
test/pipelines-it-local.yml
View file @
05913424
...
...
@@ -37,3 +37,7 @@ jobs:
cd test
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName
:
'
Trial
job
metrics
test'
-
script
:
|
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName
:
'
nnicli
test'
tools/nni_annotation/search_space_generator.py
View file @
05913424
...
...
@@ -57,8 +57,8 @@ class SearchSpaceGenerator(ast.NodeTransformer):
key
=
self
.
module_name
+
'/'
+
mutable_block
args
[
0
].
s
=
key
if
key
not
in
self
.
search_space
:
self
.
search_space
[
key
]
=
dict
()
self
.
search_space
[
key
][
mutable_layer
]
=
{
self
.
search_space
[
key
]
=
{
'_type'
:
'mutable_layer'
,
'_value'
:
{}}
self
.
search_space
[
key
][
'_value'
][
mutable_layer
]
=
{
'layer_choice'
:
[
k
.
s
for
k
in
args
[
2
].
keys
],
'optional_inputs'
:
[
k
.
s
for
k
in
args
[
5
].
keys
],
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
...
...
tools/nni_annotation/test_annotation.py
View file @
05913424
...
...
@@ -44,8 +44,9 @@ class AnnotationTestCase(TestCase):
self
.
assertEqual
(
search_space
,
json
.
load
(
f
))
def
test_code_generator
(
self
):
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated'
)
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated'
,
nas_mode
=
'classic_mode'
)
self
.
assertEqual
(
code_dir
,
'_generated'
)
self
.
_assert_source_equal
(
'testcase/annotated/nas.py'
,
'_generated/nas.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/mnist.py'
,
'_generated/mnist.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/dir/simple.py'
,
'_generated/dir/simple.py'
)
with
open
(
'testcase/usercode/nonpy.txt'
)
as
src
,
open
(
'_generated/nonpy.txt'
)
as
dst
:
...
...
tools/nni_annotation/testcase/annotated/nas.py
0 → 100644
View file @
05913424
import
nni
import
time
def
add_one
(
inputs
):
return
inputs
+
1
def
add_two
(
inputs
):
return
inputs
+
2
def
add_three
(
inputs
):
return
inputs
+
3
def
add_four
(
inputs
):
return
inputs
+
4
def
main
():
images
=
5
layer_1_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_0'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'images'
:
images
},
1
,
'classic_mode'
)
layer_2_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_1'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'layer_1_out'
:
layer_1_out
},
1
,
'classic_mode'
)
layer_3_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_2'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'layer_1_out'
:
layer_1_out
,
'layer_2_out'
:
layer_2_out
},
1
,
'classic_mode'
)
nni
.
report_intermediate_result
(
layer_1_out
)
time
.
sleep
(
2
)
nni
.
report_intermediate_result
(
layer_2_out
)
time
.
sleep
(
2
)
nni
.
report_intermediate_result
(
layer_3_out
)
time
.
sleep
(
2
)
layer_3_out
=
layer_3_out
+
10
nni
.
report_final_result
(
layer_3_out
)
if
__name__
==
'__main__'
:
main
()
tools/nni_annotation/testcase/searchspace.json
View file @
05913424
...
...
@@ -143,5 +143,47 @@
"(2 * 3 + 4)"
,
"(lambda x: 1 + x)"
]
},
"nas/mutable_block_39"
:
{
"_type"
:
"mutable_layer"
,
"_value"
:
{
"mutable_layer_0"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"images"
],
"optional_input_size"
:
1
},
"mutable_layer_1"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"layer_1_out"
],
"optional_input_size"
:
1
},
"mutable_layer_2"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"layer_1_out"
,
"layer_2_out"
],
"optional_input_size"
:
1
}
}
}
}
\ No newline at end of file
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment