Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
05913424
Commit
05913424
authored
Aug 05, 2019
by
suiguoxin
Browse files
Merge branch 'master' into quniform-tuners
parents
e3c8552f
1dab3118
Changes
86
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
794 additions
and
511 deletions
+794
-511
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+63
-4
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+81
-22
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
+168
-57
src/webui/src/components/trial-detail/Intermeidate.tsx
src/webui/src/components/trial-detail/Intermeidate.tsx
+4
-4
src/webui/src/components/trial-detail/Para.tsx
src/webui/src/components/trial-detail/Para.tsx
+15
-11
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+1
-3
src/webui/src/static/interface.ts
src/webui/src/static/interface.ts
+7
-2
src/webui/src/static/style/para.scss
src/webui/src/static/style/para.scss
+4
-0
src/webui/src/static/style/search.scss
src/webui/src/static/style/search.scss
+18
-0
src/webui/src/static/style/table.scss
src/webui/src/static/style/table.scss
+6
-2
src/webui/src/static/style/trialsDetail.scss
src/webui/src/static/style/trialsDetail.scss
+13
-0
src/webui/yarn.lock
src/webui/yarn.lock
+248
-403
test/cli_test.py
test/cli_test.py
+56
-0
test/naive_test.py
test/naive_test.py
+7
-0
test/pipelines-it-local-windows.yml
test/pipelines-it-local-windows.yml
+4
-0
test/pipelines-it-local.yml
test/pipelines-it-local.yml
+4
-0
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+2
-2
tools/nni_annotation/test_annotation.py
tools/nni_annotation/test_annotation.py
+2
-1
tools/nni_annotation/testcase/annotated/nas.py
tools/nni_annotation/testcase/annotated/nas.py
+49
-0
tools/nni_annotation/testcase/searchspace.json
tools/nni_annotation/testcase/searchspace.json
+42
-0
No files found.
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
05913424
...
@@ -193,13 +193,19 @@ class HyperoptTuner(Tuner):
...
@@ -193,13 +193,19 @@ class HyperoptTuner(Tuner):
HyperoptTuner is a tuner which using hyperopt algorithm.
HyperoptTuner is a tuner which using hyperopt algorithm.
"""
"""
def
__init__
(
self
,
algorithm_name
,
optimize_mode
=
'minimize'
):
def
__init__
(
self
,
algorithm_name
,
optimize_mode
=
'minimize'
,
parallel_optimize
=
False
,
constant_liar_type
=
'min'
):
"""
"""
Parameters
Parameters
----------
----------
algorithm_name : str
algorithm_name : str
algorithm_name includes "tpe", "random_search" and anneal".
algorithm_name includes "tpe", "random_search" and anneal".
optimize_mode : str
optimize_mode : str
parallel_optimize : bool
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
constant_liar_type : str
constant_liar_type including "min", "max" and "mean"
More detail could reference: docs/en_US/Tuner/HyperoptTuner.md
"""
"""
self
.
algorithm_name
=
algorithm_name
self
.
algorithm_name
=
algorithm_name
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
...
@@ -208,6 +214,13 @@ class HyperoptTuner(Tuner):
...
@@ -208,6 +214,13 @@ class HyperoptTuner(Tuner):
self
.
rval
=
None
self
.
rval
=
None
self
.
supplement_data_num
=
0
self
.
supplement_data_num
=
0
self
.
parallel
=
parallel_optimize
if
self
.
parallel
:
self
.
CL_rval
=
None
self
.
constant_liar_type
=
constant_liar_type
self
.
running_data
=
[]
self
.
optimal_y
=
None
def
_choose_tuner
(
self
,
algorithm_name
):
def
_choose_tuner
(
self
,
algorithm_name
):
"""
"""
Parameters
Parameters
...
@@ -269,6 +282,10 @@ class HyperoptTuner(Tuner):
...
@@ -269,6 +282,10 @@ class HyperoptTuner(Tuner):
# but it can cause deplicate parameter rarely
# but it can cause deplicate parameter rarely
total_params
=
self
.
get_suggestion
(
random_search
=
True
)
total_params
=
self
.
get_suggestion
(
random_search
=
True
)
self
.
total_data
[
parameter_id
]
=
total_params
self
.
total_data
[
parameter_id
]
=
total_params
if
self
.
parallel
:
self
.
running_data
.
append
(
parameter_id
)
params
=
split_index
(
total_params
)
params
=
split_index
(
total_params
)
return
params
return
params
...
@@ -290,10 +307,39 @@ class HyperoptTuner(Tuner):
...
@@ -290,10 +307,39 @@ class HyperoptTuner(Tuner):
raise
RuntimeError
(
'Received parameter_id not in total_data.'
)
raise
RuntimeError
(
'Received parameter_id not in total_data.'
)
params
=
self
.
total_data
[
parameter_id
]
params
=
self
.
total_data
[
parameter_id
]
# code for parallel
if
self
.
parallel
:
constant_liar
=
kwargs
.
get
(
'constant_liar'
,
False
)
if
constant_liar
:
rval
=
self
.
CL_rval
else
:
rval
=
self
.
rval
self
.
running_data
.
remove
(
parameter_id
)
# update the reward of optimal_y
if
self
.
optimal_y
is
None
:
if
self
.
constant_liar_type
==
'mean'
:
self
.
optimal_y
=
[
reward
,
1
]
else
:
self
.
optimal_y
=
reward
else
:
if
self
.
constant_liar_type
==
'mean'
:
_sum
=
self
.
optimal_y
[
0
]
+
reward
_number
=
self
.
optimal_y
[
1
]
+
1
self
.
optimal_y
=
[
_sum
,
_number
]
elif
self
.
constant_liar_type
==
'min'
:
self
.
optimal_y
=
min
(
self
.
optimal_y
,
reward
)
elif
self
.
constant_liar_type
==
'max'
:
self
.
optimal_y
=
max
(
self
.
optimal_y
,
reward
)
logger
.
debug
(
"Update optimal_y with reward, optimal_y = %s"
,
self
.
optimal_y
)
else
:
rval
=
self
.
rval
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
reward
=
-
reward
reward
=
-
reward
rval
=
self
.
rval
domain
=
rval
.
domain
domain
=
rval
.
domain
trials
=
rval
.
trials
trials
=
rval
.
trials
...
@@ -378,13 +424,26 @@ class HyperoptTuner(Tuner):
...
@@ -378,13 +424,26 @@ class HyperoptTuner(Tuner):
total_params : dict
total_params : dict
parameter suggestion
parameter suggestion
"""
"""
if
self
.
parallel
and
len
(
self
.
total_data
)
>
20
and
len
(
self
.
running_data
)
and
self
.
optimal_y
is
not
None
:
self
.
CL_rval
=
copy
.
deepcopy
(
self
.
rval
)
if
self
.
constant_liar_type
==
'mean'
:
_constant_liar_y
=
self
.
optimal_y
[
0
]
/
self
.
optimal_y
[
1
]
else
:
_constant_liar_y
=
self
.
optimal_y
for
_parameter_id
in
self
.
running_data
:
self
.
receive_trial_result
(
parameter_id
=
_parameter_id
,
parameters
=
None
,
value
=
_constant_liar_y
,
constant_liar
=
True
)
rval
=
self
.
CL_rval
random_state
=
np
.
random
.
randint
(
2
**
31
-
1
)
else
:
rval
=
self
.
rval
rval
=
self
.
rval
random_state
=
rval
.
rstate
.
randint
(
2
**
31
-
1
)
trials
=
rval
.
trials
trials
=
rval
.
trials
algorithm
=
rval
.
algo
algorithm
=
rval
.
algo
new_ids
=
rval
.
trials
.
new_trial_ids
(
1
)
new_ids
=
rval
.
trials
.
new_trial_ids
(
1
)
rval
.
trials
.
refresh
()
rval
.
trials
.
refresh
()
random_state
=
rval
.
rstate
.
randint
(
2
**
31
-
1
)
if
random_search
:
if
random_search
:
new_trials
=
hp
.
rand
.
suggest
(
new_ids
,
rval
.
domain
,
trials
,
new_trials
=
hp
.
rand
.
suggest
(
new_ids
,
rval
.
domain
,
trials
,
random_state
)
random_state
)
...
...
src/webui/src/components/TrialsDetail.tsx
View file @
05913424
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
axios
from
'
axios
'
;
import
{
MANAGER_IP
}
from
'
../static/const
'
;
import
{
MANAGER_IP
}
from
'
../static/const
'
;
import
{
Row
,
Col
,
Tabs
,
Input
,
Select
,
Button
,
Icon
}
from
'
antd
'
;
import
{
Row
,
Col
,
Tabs
,
Select
,
Button
,
Icon
}
from
'
antd
'
;
const
Option
=
Select
.
Option
;
const
Option
=
Select
.
Option
;
import
{
TableObj
,
Parameters
}
from
'
../static/interface
'
;
import
{
TableObj
,
Parameters
,
ExperimentInfo
}
from
'
../static/interface
'
;
import
{
getFinal
}
from
'
../static/function
'
;
import
{
getFinal
}
from
'
../static/function
'
;
import
DefaultPoint
from
'
./trial-detail/DefaultMetricPoint
'
;
import
DefaultPoint
from
'
./trial-detail/DefaultMetricPoint
'
;
import
Duration
from
'
./trial-detail/Duration
'
;
import
Duration
from
'
./trial-detail/Duration
'
;
...
@@ -13,6 +13,7 @@ import Intermediate from './trial-detail/Intermeidate';
...
@@ -13,6 +13,7 @@ import Intermediate from './trial-detail/Intermeidate';
import
TableList
from
'
./trial-detail/TableList
'
;
import
TableList
from
'
./trial-detail/TableList
'
;
const
TabPane
=
Tabs
.
TabPane
;
const
TabPane
=
Tabs
.
TabPane
;
import
'
../static/style/trialsDetail.scss
'
;
import
'
../static/style/trialsDetail.scss
'
;
import
'
../static/style/search.scss
'
;
interface
TrialDetailState
{
interface
TrialDetailState
{
accSource
:
object
;
accSource
:
object
;
...
@@ -20,8 +21,6 @@ interface TrialDetailState {
...
@@ -20,8 +21,6 @@ interface TrialDetailState {
tableListSource
:
Array
<
TableObj
>
;
tableListSource
:
Array
<
TableObj
>
;
searchResultSource
:
Array
<
TableObj
>
;
searchResultSource
:
Array
<
TableObj
>
;
isHasSearch
:
boolean
;
isHasSearch
:
boolean
;
experimentStatus
:
string
;
experimentPlatform
:
string
;
experimentLogCollection
:
boolean
;
experimentLogCollection
:
boolean
;
entriesTable
:
number
;
// table components val
entriesTable
:
number
;
// table components val
entriesInSelect
:
string
;
entriesInSelect
:
string
;
...
@@ -31,6 +30,9 @@ interface TrialDetailState {
...
@@ -31,6 +30,9 @@ interface TrialDetailState {
hyperCounts
:
number
;
// user click the hyper-parameter counts
hyperCounts
:
number
;
// user click the hyper-parameter counts
durationCounts
:
number
;
durationCounts
:
number
;
intermediateCounts
:
number
;
intermediateCounts
:
number
;
experimentInfo
:
ExperimentInfo
;
searchFilter
:
string
;
searchPlaceHolder
:
string
;
}
}
interface
TrialsDetailProps
{
interface
TrialsDetailProps
{
...
@@ -46,6 +48,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -46,6 +48,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
public
interAllTableList
=
2
;
public
interAllTableList
=
2
;
public
tableList
:
TableList
|
null
;
public
tableList
:
TableList
|
null
;
public
searchInput
:
HTMLInputElement
|
null
;
private
titleOfacc
=
(
private
titleOfacc
=
(
<
Title1
text
=
"Default metric"
icon
=
"3.png"
/>
<
Title1
text
=
"Default metric"
icon
=
"3.png"
/>
...
@@ -74,8 +77,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -74,8 +77,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
accNodata
:
''
,
accNodata
:
''
,
tableListSource
:
[],
tableListSource
:
[],
searchResultSource
:
[],
searchResultSource
:
[],
experimentStatus
:
''
,
experimentPlatform
:
''
,
experimentLogCollection
:
false
,
experimentLogCollection
:
false
,
entriesTable
:
20
,
entriesTable
:
20
,
entriesInSelect
:
'
20
'
,
entriesInSelect
:
'
20
'
,
...
@@ -85,7 +86,13 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -85,7 +86,13 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
isMultiPhase
:
false
,
isMultiPhase
:
false
,
hyperCounts
:
0
,
hyperCounts
:
0
,
durationCounts
:
0
,
durationCounts
:
0
,
intermediateCounts
:
0
intermediateCounts
:
0
,
experimentInfo
:
{
platform
:
''
,
optimizeMode
:
'
maximize
'
},
searchFilter
:
'
id
'
,
searchPlaceHolder
:
'
Search by id
'
};
};
}
}
...
@@ -212,16 +219,34 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -212,16 +219,34 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}));
}));
}
}
}
else
{
}
else
{
const
{
tableListSource
}
=
this
.
state
;
const
{
tableListSource
,
searchFilter
}
=
this
.
state
;
const
searchResultList
:
Array
<
TableObj
>
=
[];
const
searchResultList
:
Array
<
TableObj
>
=
[];
Object
.
keys
(
tableListSource
).
map
(
key
=>
{
Object
.
keys
(
tableListSource
).
map
(
key
=>
{
const
item
=
tableListSource
[
key
];
const
item
=
tableListSource
[
key
];
if
(
item
.
sequenceId
.
toString
()
===
targetValue
switch
(
searchFilter
)
{
||
item
.
id
.
includes
(
targetValue
)
case
'
id
'
:
||
item
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
())
if
(
item
.
id
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
)
{
searchResultList
.
push
(
item
);
}
break
;
case
'
Trial No.
'
:
if
(
item
.
sequenceId
.
toString
()
===
targetValue
)
{
searchResultList
.
push
(
item
);
}
break
;
case
'
status
'
:
if
(
item
.
status
.
toUpperCase
().
includes
(
targetValue
.
toUpperCase
()))
{
searchResultList
.
push
(
item
);
searchResultList
.
push
(
item
);
}
}
break
;
case
'
parameters
'
:
const
strParameters
=
JSON
.
stringify
(
item
.
description
.
parameters
,
null
,
4
);
if
(
strParameters
.
includes
(
targetValue
))
{
searchResultList
.
push
(
item
);
}
break
;
default
:
}
});
});
if
(
this
.
_isMounted
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
this
.
setState
(()
=>
({
...
@@ -282,6 +307,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -282,6 +307,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
alert
(
'
TableList component was not properly initialized.
'
);
alert
(
'
TableList component was not properly initialized.
'
);
}
}
getSearchFilter
=
(
value
:
string
)
=>
{
// clear input value and re-render table
if
(
this
.
searchInput
!==
null
)
{
this
.
searchInput
.
value
=
''
;
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
isHasSearch
:
false
}));
}
}
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
searchFilter
:
value
,
searchPlaceHolder
:
`Search by
${
value
}
`
}));
}
}
// get and set logCollection val
// get and set logCollection val
checkExperimentPlatform
=
()
=>
{
checkExperimentPlatform
=
()
=>
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
axios
(
`
${
MANAGER_IP
}
/experiment`
,
{
...
@@ -289,7 +327,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -289,7 +327,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
})
})
.
then
(
res
=>
{
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
if
(
res
.
status
===
200
)
{
const
trainingPlatform
=
res
.
data
.
params
.
trainingServicePlatform
!==
undefined
const
trainingPlatform
:
string
=
res
.
data
.
params
.
trainingServicePlatform
!==
undefined
?
?
res
.
data
.
params
.
trainingServicePlatform
res
.
data
.
params
.
trainingServicePlatform
:
:
...
@@ -299,12 +337,24 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -299,12 +337,24 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
let
expLogCollection
:
boolean
=
false
;
let
expLogCollection
:
boolean
=
false
;
const
isMultiy
:
boolean
=
res
.
data
.
params
.
multiPhase
!==
undefined
const
isMultiy
:
boolean
=
res
.
data
.
params
.
multiPhase
!==
undefined
?
res
.
data
.
params
.
multiPhase
:
false
;
?
res
.
data
.
params
.
multiPhase
:
false
;
const
tuner
=
res
.
data
.
params
.
tuner
;
// I'll set optimize is maximize if user not set optimize
let
optimize
:
string
=
'
maximize
'
;
if
(
tuner
!==
undefined
)
{
if
(
tuner
.
classArgs
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
!==
undefined
)
{
if
(
tuner
.
classArgs
.
optimize_mode
===
'
minimize
'
)
{
optimize
=
'
minimize
'
;
}
}
}
}
if
(
logCollection
!==
undefined
&&
logCollection
!==
'
none
'
)
{
if
(
logCollection
!==
undefined
&&
logCollection
!==
'
none
'
)
{
expLogCollection
=
true
;
expLogCollection
=
true
;
}
}
if
(
this
.
_isMounted
)
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
this
.
setState
({
experiment
P
latform
:
trainingPlatform
,
experiment
Info
:
{
p
latform
:
trainingPlatform
,
optimizeMode
:
optimize
},
searchSpace
:
res
.
data
.
params
.
searchSpace
,
searchSpace
:
res
.
data
.
params
.
searchSpace
,
experimentLogCollection
:
expLogCollection
,
experimentLogCollection
:
expLogCollection
,
isMultiPhase
:
isMultiy
isMultiPhase
:
isMultiy
...
@@ -343,8 +393,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -343,8 +393,8 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
const
{
const
{
tableListSource
,
searchResultSource
,
isHasSearch
,
isMultiPhase
,
tableListSource
,
searchResultSource
,
isHasSearch
,
isMultiPhase
,
entriesTable
,
experiment
Platform
,
searchSpace
,
experimentLogCollection
,
entriesTable
,
experiment
Info
,
searchSpace
,
experimentLogCollection
,
whichGraph
whichGraph
,
searchPlaceHolder
}
=
this
.
state
;
}
=
this
.
state
;
const
source
=
isHasSearch
?
searchResultSource
:
tableListSource
;
const
source
=
isHasSearch
?
searchResultSource
:
tableListSource
;
return
(
return
(
...
@@ -354,9 +404,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -354,9 +404,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<
TabPane
tab
=
{
this
.
titleOfacc
}
key
=
"1"
>
<
TabPane
tab
=
{
this
.
titleOfacc
}
key
=
"1"
>
<
Row
className
=
"graph"
>
<
Row
className
=
"graph"
>
<
DefaultPoint
<
DefaultPoint
height
=
{
4
3
2
}
height
=
{
4
0
2
}
showSource
=
{
source
}
showSource
=
{
source
}
whichGraph
=
{
whichGraph
}
whichGraph
=
{
whichGraph
}
optimize
=
{
experimentInfo
.
optimizeMode
}
/>
/>
</
Row
>
</
Row
>
</
TabPane
>
</
TabPane
>
...
@@ -408,11 +459,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -408,11 +459,19 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
>
Compare
Compare
</
Button
>
</
Button
>
<
Input
<
Select
defaultValue
=
"id"
className
=
"filter"
onSelect
=
{
this
.
getSearchFilter
}
>
<
Option
value
=
"id"
>
Id
</
Option
>
<
Option
value
=
"Trial No."
>
Trial No.
</
Option
>
<
Option
value
=
"status"
>
Status
</
Option
>
<
Option
value
=
"parameters"
>
Parameters
</
Option
>
</
Select
>
<
input
type
=
"text"
type
=
"text"
placeholder
=
"Search by id, trial No. or status"
className
=
"search-input"
placeholder
=
{
searchPlaceHolder
}
onChange
=
{
this
.
searchTrial
}
onChange
=
{
this
.
searchTrial
}
style
=
{
{
width
:
230
,
marginLeft
:
6
}
}
style
=
{
{
width
:
230
}
}
ref
=
{
text
=>
(
this
.
searchInput
)
=
text
}
/>
/>
</
Col
>
</
Col
>
</
Row
>
</
Row
>
...
@@ -420,7 +479,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -420,7 +479,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
entries
=
{
entriesTable
}
entries
=
{
entriesTable
}
tableSource
=
{
source
}
tableSource
=
{
source
}
isMultiPhase
=
{
isMultiPhase
}
isMultiPhase
=
{
isMultiPhase
}
platform
=
{
experiment
P
latform
}
platform
=
{
experiment
Info
.
p
latform
}
updateList
=
{
this
.
getDetailSource
}
updateList
=
{
this
.
getDetailSource
}
logCollection
=
{
experimentLogCollection
}
logCollection
=
{
experimentLogCollection
}
ref
=
{
(
tabList
)
=>
this
.
tableList
=
tabList
}
ref
=
{
(
tabList
)
=>
this
.
tableList
=
tabList
}
...
...
src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
View file @
05913424
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
{
Switch
}
from
'
antd
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
filterByStatus
}
from
'
../../static/function
'
;
import
{
filterByStatus
}
from
'
../../static/function
'
;
import
{
TableObj
,
DetailAccurPoint
,
TooltipForAccuracy
}
from
'
../../static/interface
'
;
import
{
TableObj
,
DetailAccurPoint
,
TooltipForAccuracy
}
from
'
../../static/interface
'
;
...
@@ -10,32 +11,36 @@ interface DefaultPointProps {
...
@@ -10,32 +11,36 @@ interface DefaultPointProps {
showSource
:
Array
<
TableObj
>
;
showSource
:
Array
<
TableObj
>
;
height
:
number
;
height
:
number
;
whichGraph
:
string
;
whichGraph
:
string
;
optimize
:
string
;
}
}
interface
DefaultPointState
{
interface
DefaultPointState
{
defaultSource
:
object
;
defaultSource
:
object
;
accNodata
:
string
;
accNodata
:
string
;
succeedTrials
:
number
;
succeedTrials
:
number
;
isViewBestCurve
:
boolean
;
}
}
class
DefaultPoint
extends
React
.
Component
<
DefaultPointProps
,
DefaultPointState
>
{
class
DefaultPoint
extends
React
.
Component
<
DefaultPointProps
,
DefaultPointState
>
{
public
_isMounted
=
false
;
public
_is
Default
Mounted
=
false
;
constructor
(
props
:
DefaultPointProps
)
{
constructor
(
props
:
DefaultPointProps
)
{
super
(
props
);
super
(
props
);
this
.
state
=
{
this
.
state
=
{
defaultSource
:
{},
defaultSource
:
{},
accNodata
:
''
,
accNodata
:
''
,
succeedTrials
:
10000000
succeedTrials
:
10000000
,
isViewBestCurve
:
false
};
};
}
}
defaultMetric
=
(
succeedSource
:
Array
<
TableObj
>
)
=>
{
defaultMetric
=
(
succeedSource
:
Array
<
TableObj
>
,
isCurve
:
boolean
)
=>
{
const
{
optimize
}
=
this
.
props
;
const
accSource
:
Array
<
DetailAccurPoint
>
=
[];
const
accSource
:
Array
<
DetailAccurPoint
>
=
[];
const
showSource
:
Array
<
TableObj
>
=
succeedSource
.
filter
(
filterByStatus
);
const
showSource
:
Array
<
TableObj
>
=
succeedSource
.
filter
(
filterByStatus
);
const
lengthOfSource
=
showSource
.
length
;
const
lengthOfSource
=
showSource
.
length
;
const
tooltipDefault
=
lengthOfSource
===
0
?
'
No data
'
:
''
;
const
tooltipDefault
=
lengthOfSource
===
0
?
'
No data
'
:
''
;
if
(
this
.
_isMounted
===
true
)
{
if
(
this
.
_is
Default
Mounted
===
true
)
{
this
.
setState
(()
=>
({
this
.
setState
(()
=>
({
succeedTrials
:
lengthOfSource
,
succeedTrials
:
lengthOfSource
,
accNodata
:
tooltipDefault
accNodata
:
tooltipDefault
...
@@ -55,34 +60,125 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -55,34 +60,125 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
type
:
'
value
'
,
type
:
'
value
'
,
}
}
};
};
if
(
this
.
_isMounted
===
true
)
{
if
(
this
.
_is
Default
Mounted
===
true
)
{
this
.
setState
(()
=>
({
this
.
setState
(()
=>
({
defaultSource
:
nullGraph
defaultSource
:
nullGraph
}));
}));
}
}
}
else
{
}
else
{
const
resultList
:
Array
<
number
|
string
>
[]
=
[];
const
resultList
:
Array
<
number
|
object
>
[]
=
[];
const
lineListDefault
:
Array
<
number
>
=
[];
Object
.
keys
(
showSource
).
map
(
item
=>
{
Object
.
keys
(
showSource
).
map
(
item
=>
{
const
temp
=
showSource
[
item
];
const
temp
=
showSource
[
item
];
if
(
temp
.
acc
!==
undefined
)
{
if
(
temp
.
acc
!==
undefined
)
{
if
(
temp
.
acc
.
default
!==
undefined
)
{
if
(
temp
.
acc
.
default
!==
undefined
)
{
const
searchSpace
=
temp
.
description
.
parameters
;
const
searchSpace
=
temp
.
description
.
parameters
;
lineListDefault
.
push
(
temp
.
acc
.
default
);
accSource
.
push
({
accSource
.
push
({
acc
:
temp
.
acc
.
default
,
acc
:
temp
.
acc
.
default
,
index
:
temp
.
sequenceId
,
index
:
temp
.
sequenceId
,
searchSpace
:
JSON
.
stringify
(
searchSpace
)
searchSpace
:
searchSpace
});
});
}
}
}
}
});
});
// deal with best metric line
const
bestCurve
:
Array
<
number
|
object
>
[]
=
[];
// best curve data source
bestCurve
.
push
([
0
,
lineListDefault
[
0
],
accSource
[
0
].
searchSpace
]);
// push the first value
if
(
optimize
===
'
maximize
'
)
{
for
(
let
i
=
1
;
i
<
lineListDefault
.
length
;
i
++
)
{
const
val
=
lineListDefault
[
i
];
const
latest
=
bestCurve
[
bestCurve
.
length
-
1
][
1
];
if
(
val
>=
latest
)
{
bestCurve
.
push
([
i
,
val
,
accSource
[
i
].
searchSpace
]);
}
else
{
bestCurve
.
push
([
i
,
latest
,
accSource
[
i
].
searchSpace
]);
}
}
}
else
{
for
(
let
i
=
1
;
i
<
lineListDefault
.
length
;
i
++
)
{
const
val
=
lineListDefault
[
i
];
const
latest
=
bestCurve
[
bestCurve
.
length
-
1
][
1
];
if
(
val
<=
latest
)
{
bestCurve
.
push
([
i
,
val
,
accSource
[
i
].
searchSpace
]);
}
else
{
bestCurve
.
push
([
i
,
latest
,
accSource
[
i
].
searchSpace
]);
}
}
}
Object
.
keys
(
accSource
).
map
(
item
=>
{
Object
.
keys
(
accSource
).
map
(
item
=>
{
const
items
=
accSource
[
item
];
const
items
=
accSource
[
item
];
let
temp
:
Array
<
number
|
string
>
;
let
temp
:
Array
<
number
|
object
>
;
temp
=
[
items
.
index
,
items
.
acc
,
JSON
.
parse
(
items
.
searchSpace
)
];
temp
=
[
items
.
index
,
items
.
acc
,
items
.
searchSpace
];
resultList
.
push
(
temp
);
resultList
.
push
(
temp
);
});
});
// isViewBestCurve: false show default metric graph
// isViewBestCurve: true show best curve
if
(
isCurve
===
true
)
{
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
this
.
drawBestcurve
(
bestCurve
,
resultList
)
}));
}
}
else
{
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
this
.
drawDefaultMetric
(
resultList
)
}));
}
}
}
}
const
allAcuracy
=
{
drawBestcurve
=
(
realDefault
:
Array
<
number
|
object
>
[],
resultList
:
Array
<
number
|
object
>
[])
=>
{
return
{
grid
:
{
left
:
'
8%
'
},
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForAccuracy
)
{
if
(
data
.
data
[
0
]
<
realDefault
.
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
formatter
:
function
(
data
:
TooltipForAccuracy
)
{
const
result
=
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial No.:
'
+
data
.
data
[
0
]
+
'
</div>
'
+
'
<div>Optimization curve:
'
+
data
.
data
[
1
]
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
data
.
data
[
2
],
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
return
result
;
}
},
xAxis
:
{
name
:
'
Trial
'
,
type
:
'
category
'
,
},
yAxis
:
{
name
:
'
Default metric
'
,
type
:
'
value
'
,
scale
:
true
},
series
:
[{
symbolSize
:
6
,
type
:
'
scatter
'
,
data
:
resultList
},
{
type
:
'
line
'
,
lineStyle
:
{
color
:
'
#FF6600
'
},
data
:
realDefault
}]
};
}
drawDefaultMetric
=
(
resultList
:
Array
<
number
|
object
>
[])
=>
{
return
{
grid
:
{
grid
:
{
left
:
'
8%
'
left
:
'
8%
'
},
},
...
@@ -114,6 +210,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -114,6 +210,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
yAxis
:
{
yAxis
:
{
name
:
'
Default metric
'
,
name
:
'
Default metric
'
,
type
:
'
value
'
,
type
:
'
value
'
,
scale
:
true
},
},
series
:
[{
series
:
[{
symbolSize
:
6
,
symbolSize
:
6
,
...
@@ -121,11 +218,15 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -121,11 +218,15 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
data
:
resultList
data
:
resultList
}]
}]
};
};
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
defaultSource
:
allAcuracy
}));
}
}
loadDefault
=
(
checked
:
boolean
)
=>
{
// checked: true show best metric curve
const
{
showSource
}
=
this
.
props
;
if
(
this
.
_isDefaultMounted
===
true
)
{
this
.
defaultMetric
(
showSource
,
checked
);
// ** deal with data and then update view layer
this
.
setState
(()
=>
({
isViewBestCurve
:
checked
}));
}
}
}
}
...
@@ -133,16 +234,21 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -133,16 +234,21 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
componentWillReceiveProps
(
nextProps
:
DefaultPointProps
)
{
componentWillReceiveProps
(
nextProps
:
DefaultPointProps
)
{
const
{
whichGraph
,
showSource
}
=
nextProps
;
const
{
whichGraph
,
showSource
}
=
nextProps
;
const
{
isViewBestCurve
}
=
this
.
state
;
if
(
whichGraph
===
'
1
'
)
{
if
(
whichGraph
===
'
1
'
)
{
this
.
defaultMetric
(
showSource
);
this
.
defaultMetric
(
showSource
,
isViewBestCurve
);
}
}
}
}
shouldComponentUpdate
(
nextProps
:
DefaultPointProps
,
nextState
:
DefaultPointState
)
{
shouldComponentUpdate
(
nextProps
:
DefaultPointProps
,
nextState
:
DefaultPointState
)
{
const
{
whichGraph
}
=
nextProps
;
const
{
whichGraph
}
=
nextProps
;
const
succTrial
=
this
.
state
.
succeedTrials
;
const
{
succeedTrials
}
=
nextState
;
if
(
whichGraph
===
'
1
'
)
{
if
(
whichGraph
===
'
1
'
)
{
const
{
succeedTrials
,
isViewBestCurve
}
=
nextState
;
const
succTrial
=
this
.
state
.
succeedTrials
;
const
isViewBestCurveBefore
=
this
.
state
.
isViewBestCurve
;
if
(
isViewBestCurveBefore
!==
isViewBestCurve
)
{
return
true
;
}
if
(
succeedTrials
!==
succTrial
)
{
if
(
succeedTrials
!==
succTrial
)
{
return
true
;
return
true
;
}
}
...
@@ -152,11 +258,11 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -152,11 +258,11 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
}
componentDidMount
()
{
componentDidMount
()
{
this
.
_isMounted
=
true
;
this
.
_is
Default
Mounted
=
true
;
}
}
componentWillUnmount
()
{
componentWillUnmount
()
{
this
.
_isMounted
=
false
;
this
.
_is
Default
Mounted
=
false
;
}
}
render
()
{
render
()
{
...
@@ -164,6 +270,12 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -164,6 +270,12 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
const
{
defaultSource
,
accNodata
}
=
this
.
state
;
const
{
defaultSource
,
accNodata
}
=
this
.
state
;
return
(
return
(
<
div
>
<
div
>
<
div
className
=
"default-metric"
>
<
div
className
=
"position"
>
<
span
className
=
"bold"
>
optimization curve
</
span
>
<
Switch
defaultChecked
=
{
false
}
onChange
=
{
this
.
loadDefault
}
/>
</
div
>
</
div
>
<
ReactEcharts
<
ReactEcharts
option
=
{
defaultSource
}
option
=
{
defaultSource
}
style
=
{
{
style
=
{
{
...
@@ -173,7 +285,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
...
@@ -173,7 +285,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}
}
}
}
theme
=
"my_theme"
theme
=
"my_theme"
notMerge
=
{
true
}
// update now
notMerge
=
{
true
}
// update now
// lazyUpdate={true}
/>
/>
<
div
className
=
"showMess"
>
{
accNodata
}
</
div
>
<
div
className
=
"showMess"
>
{
accNodata
}
</
div
>
</
div
>
</
div
>
...
...
src/webui/src/components/trial-detail/Intermeidate.tsx
View file @
05913424
...
@@ -114,7 +114,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -114,7 +114,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
},
yAxis
:
{
yAxis
:
{
type
:
'
value
'
,
type
:
'
value
'
,
name
:
'
m
etric
'
name
:
'
M
etric
'
},
},
series
:
trialIntermediate
series
:
trialIntermediate
};
};
...
@@ -136,7 +136,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -136,7 +136,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
},
},
yAxis
:
{
yAxis
:
{
type
:
'
value
'
,
type
:
'
value
'
,
name
:
'
m
etric
'
name
:
'
M
etric
'
}
}
};
};
if
(
this
.
_isMounted
)
{
if
(
this
.
_isMounted
)
{
...
@@ -283,9 +283,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
...
@@ -283,9 +283,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
{
/* style in para.scss */
}
{
/* style in para.scss */
}
<
Row
className
=
"meline intermediate"
>
<
Row
className
=
"meline intermediate"
>
<
Col
span
=
{
8
}
/>
<
Col
span
=
{
8
}
/>
<
Col
span
=
{
3
}
style
=
{
{
height
:
34
}
}
>
<
Col
span
=
{
3
}
className
=
"inter-filter-btn"
>
{
/* filter message */
}
{
/* filter message */
}
<
span
>
f
ilter
</
span
>
<
span
>
F
ilter
</
span
>
<
Switch
<
Switch
defaultChecked
=
{
false
}
defaultChecked
=
{
false
}
onChange
=
{
this
.
switchTurn
}
onChange
=
{
this
.
switchTurn
}
...
...
src/webui/src/components/trial-detail/Para.tsx
View file @
05913424
...
@@ -87,13 +87,10 @@ class Para extends React.Component<ParaProps, ParaState> {
...
@@ -87,13 +87,10 @@ class Para extends React.Component<ParaProps, ParaState> {
let
temp
:
Array
<
number
>
=
[];
let
temp
:
Array
<
number
>
=
[];
for
(
let
i
=
0
;
i
<
dimName
.
length
;
i
++
)
{
for
(
let
i
=
0
;
i
<
dimName
.
length
;
i
++
)
{
if
(
'
type
'
in
parallelAxis
[
i
])
{
if
(
'
type
'
in
parallelAxis
[
i
])
{
temp
.
push
(
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]].
toString
());
eachTrialParams
[
item
][
dimName
[
i
]].
toString
()
);
}
else
{
}
else
{
temp
.
push
(
// default metric
eachTrialParams
[
item
][
dimName
[
i
]]
temp
.
push
(
eachTrialParams
[
item
][
dimName
[
i
]]);
);
}
}
}
}
paraYdata
.
push
(
temp
);
paraYdata
.
push
(
temp
);
...
@@ -199,11 +196,18 @@ class Para extends React.Component<ParaProps, ParaState> {
...
@@ -199,11 +196,18 @@ class Para extends React.Component<ParaProps, ParaState> {
break
;
break
;
// support log distribute
// support log distribute
case
'
loguniform
'
:
case
'
loguniform
'
:
if
(
lenOfDataSource
>
1
)
{
parallelAxis
.
push
({
parallelAxis
.
push
({
dim
:
i
,
dim
:
i
,
name
:
dimName
[
i
],
name
:
dimName
[
i
],
type
:
'
log
'
,
type
:
'
log
'
,
});
});
}
else
{
parallelAxis
.
push
({
dim
:
i
,
name
:
dimName
[
i
]
});
}
break
;
break
;
default
:
default
:
...
...
src/webui/src/components/trial-detail/TableList.tsx
View file @
05913424
...
@@ -321,9 +321,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
...
@@ -321,9 +321,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key
:
'
sequenceId
'
,
key
:
'
sequenceId
'
,
width
:
120
,
width
:
120
,
className
:
'
tableHead
'
,
className
:
'
tableHead
'
,
sorter
:
sorter
:
(
a
:
TableObj
,
b
:
TableObj
)
=>
(
a
.
sequenceId
as
number
)
-
(
b
.
sequenceId
as
number
)
(
a
:
TableObj
,
b
:
TableObj
)
=>
(
a
.
sequenceId
as
number
)
-
(
b
.
sequenceId
as
number
)
});
});
break
;
break
;
case
'
ID
'
:
case
'
ID
'
:
...
...
src/webui/src/static/interface.ts
View file @
05913424
...
@@ -59,7 +59,7 @@ interface AccurPoint {
...
@@ -59,7 +59,7 @@ interface AccurPoint {
interface
DetailAccurPoint
{
interface
DetailAccurPoint
{
acc
:
number
;
acc
:
number
;
index
:
number
;
index
:
number
;
searchSpace
:
string
;
searchSpace
:
object
;
}
}
interface
TooltipForIntermediate
{
interface
TooltipForIntermediate
{
...
@@ -117,8 +117,13 @@ interface Intermedia {
...
@@ -117,8 +117,13 @@ interface Intermedia {
hyperPara
:
object
;
// each trial hyperpara value
hyperPara
:
object
;
// each trial hyperpara value
}
}
interface
ExperimentInfo
{
platform
:
string
;
optimizeMode
:
string
;
}
export
{
export
{
TableObj
,
Parameters
,
Experiment
,
AccurPoint
,
TrialNumber
,
TrialJob
,
TableObj
,
Parameters
,
Experiment
,
AccurPoint
,
TrialNumber
,
TrialJob
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalResult
,
FinalType
,
DetailAccurPoint
,
TooltipForAccuracy
,
ParaObj
,
Dimobj
,
FinalResult
,
FinalType
,
TooltipForIntermediate
,
SearchSpace
,
Intermedia
TooltipForIntermediate
,
SearchSpace
,
Intermedia
,
ExperimentInfo
};
};
src/webui/src/static/style/para.scss
View file @
05913424
...
@@ -36,6 +36,10 @@
...
@@ -36,6 +36,10 @@
.strange
{
.strange
{
margin-top
:
2px
;
margin-top
:
2px
;
}
}
.inter-filter-btn
{
height
:
34px
;
line-height
:
34px
;
}
.range
{
.range
{
.heng
{
.heng
{
margin-left
:
6px
;
margin-left
:
6px
;
...
...
src/webui/src/static/style/search.scss
View file @
05913424
...
@@ -11,6 +11,24 @@
...
@@ -11,6 +11,24 @@
color
:
#0071BC
;
color
:
#0071BC
;
border-radius
:
0
;
border-radius
:
0
;
}
}
.filter
{
width
:
100px
;
margin-left
:
8px
;
.ant-select-selection-selected-value
{
font-size
:
14px
;
}
}
.search-input
{
height
:
32px
;
outline
:
none
;
border
:
1px
solid
#d9d9d9
;
border-left
:
none
;
padding-left
:
8px
;
color
:
#333
;
}
.
search-input
:
:
placeholder
{
color
:
DarkGray
;
}
}
}
.entry
{
.entry
{
width
:
120px
;
width
:
120px
;
...
...
src/webui/src/static/style/table.scss
View file @
05913424
...
@@ -31,14 +31,12 @@
...
@@ -31,14 +31,12 @@
text-align
:
center
;
text-align
:
center
;
color
:
#212121
;
color
:
#212121
;
font-size
:
14px
;
font-size
:
14px
;
/* background-color: #f2f2f2; */
}
}
th
{
th
{
padding
:
2px
;
padding
:
2px
;
background-color
:white
!
important
;
background-color
:white
!
important
;
font-size
:
14px
;
font-size
:
14px
;
color
:
#808080
;
color
:
#808080
;
border-bottom
:
1px
solid
#d0d0d0
;
text-align
:
center
;
text-align
:
center
;
}
}
...
@@ -105,3 +103,9 @@
...
@@ -105,3 +103,9 @@
.ant-table-selection
{
.ant-table-selection
{
display
:
none
;
display
:
none
;
}
}
/* fix the border-bottom bug in firefox and edge */
.
ant-table-thead
>
tr
>
th
.
ant-table-column-sorters
:
:
before
{
padding-bottom
:
25px
;
border-bottom
:
1px
solid
#e8e8e8
;
}
\ No newline at end of file
src/webui/src/static/style/trialsDetail.scss
View file @
05913424
...
@@ -70,3 +70,16 @@
...
@@ -70,3 +70,16 @@
.allList
{
.allList
{
margin-top
:
15px
;
margin-top
:
15px
;
}
}
.default-metric
{
width
:
90%
;
text-align
:
right
;
margin-top
:
15px
;
.position
{
color
:
#333
;
.bold
{
font-weight
:
600
;
margin-right
:
10px
;
}
}
}
src/webui/yarn.lock
View file @
05913424
This diff is collapsed.
Click to expand it.
test/cli_test.py
0 → 100644
View file @
05913424
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import
sys
import
time
import
traceback
from
utils
import
GREEN
,
RED
,
CLEAR
,
setup_experiment
def
test_nni_cli
():
import
nnicli
as
nc
config_file
=
'config_test/examples/mnist.test.yml'
try
:
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
time
.
sleep
(
6
)
print
(
GREEN
+
'Testing nnicli:'
+
config_file
+
CLEAR
)
nc
.
start_nni
(
config_file
)
time
.
sleep
(
3
)
nc
.
set_endpoint
(
'http://localhost:8080'
)
print
(
nc
.
version
())
print
(
nc
.
get_job_statistics
())
print
(
nc
.
get_experiment_status
())
nc
.
list_trial_jobs
()
print
(
GREEN
+
'Test nnicli {}: TEST PASS'
.
format
(
config_file
)
+
CLEAR
)
except
Exception
as
error
:
print
(
RED
+
'Test nnicli {}: TEST FAIL'
.
format
(
config_file
)
+
CLEAR
)
print
(
'%r'
%
error
)
traceback
.
print_exc
()
raise
error
finally
:
nc
.
stop_nni
()
if
__name__
==
'__main__'
:
installed
=
(
sys
.
argv
[
-
1
]
!=
'--preinstall'
)
setup_experiment
(
installed
)
test_nni_cli
()
test/naive_test.py
View file @
05913424
...
@@ -88,6 +88,7 @@ def stop_experiment_test():
...
@@ -88,6 +88,7 @@ def stop_experiment_test():
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8080'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8080'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8888'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8888'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8989'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8989'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8990'
],
check
=
True
)
# test cmd 'nnictl stop id`
# test cmd 'nnictl stop id`
experiment_id
=
get_experiment_id
(
EXPERIMENT_URL
)
experiment_id
=
get_experiment_id
(
EXPERIMENT_URL
)
...
@@ -96,6 +97,12 @@ def stop_experiment_test():
...
@@ -96,6 +97,12 @@ def stop_experiment_test():
snooze
()
snooze
()
assert
not
detect_port
(
8080
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
assert
not
detect_port
(
8080
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
# test cmd `nnictl stop --port`
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'--port'
,
'8990'
])
assert
proc
.
returncode
==
0
,
'`nnictl stop %s` failed with code %d'
%
(
experiment_id
,
proc
.
returncode
)
snooze
()
assert
not
detect_port
(
8990
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
# test cmd `nnictl stop all`
# test cmd `nnictl stop all`
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'all'
])
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'all'
])
assert
proc
.
returncode
==
0
,
'`nnictl stop all` failed with code %d'
%
proc
.
returncode
assert
proc
.
returncode
==
0
,
'`nnictl stop all` failed with code %d'
%
proc
.
returncode
...
...
test/pipelines-it-local-windows.yml
View file @
05913424
...
@@ -36,3 +36,7 @@ jobs:
...
@@ -36,3 +36,7 @@ jobs:
cd test
cd test
python metrics_test.py
python metrics_test.py
displayName
:
'
Trial
job
metrics
test'
displayName
:
'
Trial
job
metrics
test'
-
script
:
|
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName
:
'
nnicli
test'
test/pipelines-it-local.yml
View file @
05913424
...
@@ -37,3 +37,7 @@ jobs:
...
@@ -37,3 +37,7 @@ jobs:
cd test
cd test
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName
:
'
Trial
job
metrics
test'
displayName
:
'
Trial
job
metrics
test'
-
script
:
|
cd test
PATH=$HOME/.local/bin:$PATH python3 cli_test.py
displayName
:
'
nnicli
test'
tools/nni_annotation/search_space_generator.py
View file @
05913424
...
@@ -57,8 +57,8 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -57,8 +57,8 @@ class SearchSpaceGenerator(ast.NodeTransformer):
key
=
self
.
module_name
+
'/'
+
mutable_block
key
=
self
.
module_name
+
'/'
+
mutable_block
args
[
0
].
s
=
key
args
[
0
].
s
=
key
if
key
not
in
self
.
search_space
:
if
key
not
in
self
.
search_space
:
self
.
search_space
[
key
]
=
dict
()
self
.
search_space
[
key
]
=
{
'_type'
:
'mutable_layer'
,
'_value'
:
{}}
self
.
search_space
[
key
][
mutable_layer
]
=
{
self
.
search_space
[
key
][
'_value'
][
mutable_layer
]
=
{
'layer_choice'
:
[
k
.
s
for
k
in
args
[
2
].
keys
],
'layer_choice'
:
[
k
.
s
for
k
in
args
[
2
].
keys
],
'optional_inputs'
:
[
k
.
s
for
k
in
args
[
5
].
keys
],
'optional_inputs'
:
[
k
.
s
for
k
in
args
[
5
].
keys
],
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
...
...
tools/nni_annotation/test_annotation.py
View file @
05913424
...
@@ -44,8 +44,9 @@ class AnnotationTestCase(TestCase):
...
@@ -44,8 +44,9 @@ class AnnotationTestCase(TestCase):
self
.
assertEqual
(
search_space
,
json
.
load
(
f
))
self
.
assertEqual
(
search_space
,
json
.
load
(
f
))
def
test_code_generator
(
self
):
def
test_code_generator
(
self
):
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated'
)
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated'
,
nas_mode
=
'classic_mode'
)
self
.
assertEqual
(
code_dir
,
'_generated'
)
self
.
assertEqual
(
code_dir
,
'_generated'
)
self
.
_assert_source_equal
(
'testcase/annotated/nas.py'
,
'_generated/nas.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/mnist.py'
,
'_generated/mnist.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/mnist.py'
,
'_generated/mnist.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/dir/simple.py'
,
'_generated/dir/simple.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/dir/simple.py'
,
'_generated/dir/simple.py'
)
with
open
(
'testcase/usercode/nonpy.txt'
)
as
src
,
open
(
'_generated/nonpy.txt'
)
as
dst
:
with
open
(
'testcase/usercode/nonpy.txt'
)
as
src
,
open
(
'_generated/nonpy.txt'
)
as
dst
:
...
...
tools/nni_annotation/testcase/annotated/nas.py
0 → 100644
View file @
05913424
import
nni
import
time
def
add_one
(
inputs
):
return
inputs
+
1
def
add_two
(
inputs
):
return
inputs
+
2
def
add_three
(
inputs
):
return
inputs
+
3
def
add_four
(
inputs
):
return
inputs
+
4
def
main
():
images
=
5
layer_1_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_0'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'images'
:
images
},
1
,
'classic_mode'
)
layer_2_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_1'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'layer_1_out'
:
layer_1_out
},
1
,
'classic_mode'
)
layer_3_out
=
nni
.
mutable_layer
(
'mutable_block_39'
,
'mutable_layer_2'
,
{
'add_one()'
:
add_one
,
'add_two()'
:
add_two
,
'add_three()'
:
add_three
,
'add_four()'
:
add_four
},
{
'add_one()'
:
{},
'add_two()'
:
{},
'add_three()'
:
{},
'add_four()'
:
{}},
[],
{
'layer_1_out'
:
layer_1_out
,
'layer_2_out'
:
layer_2_out
},
1
,
'classic_mode'
)
nni
.
report_intermediate_result
(
layer_1_out
)
time
.
sleep
(
2
)
nni
.
report_intermediate_result
(
layer_2_out
)
time
.
sleep
(
2
)
nni
.
report_intermediate_result
(
layer_3_out
)
time
.
sleep
(
2
)
layer_3_out
=
layer_3_out
+
10
nni
.
report_final_result
(
layer_3_out
)
if
__name__
==
'__main__'
:
main
()
tools/nni_annotation/testcase/searchspace.json
View file @
05913424
...
@@ -143,5 +143,47 @@
...
@@ -143,5 +143,47 @@
"(2 * 3 + 4)"
,
"(2 * 3 + 4)"
,
"(lambda x: 1 + x)"
"(lambda x: 1 + x)"
]
]
},
"nas/mutable_block_39"
:
{
"_type"
:
"mutable_layer"
,
"_value"
:
{
"mutable_layer_0"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"images"
],
"optional_input_size"
:
1
},
"mutable_layer_1"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"layer_1_out"
],
"optional_input_size"
:
1
},
"mutable_layer_2"
:
{
"layer_choice"
:
[
"add_one()"
,
"add_two()"
,
"add_three()"
,
"add_four()"
],
"optional_inputs"
:
[
"layer_1_out"
,
"layer_2_out"
],
"optional_input_size"
:
1
}
}
}
}
}
}
\ No newline at end of file
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment