Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
1500458a
Unverified
Commit
1500458a
authored
Jun 24, 2019
by
SparkSnail
Committed by
GitHub
Jun 24, 2019
Browse files
Merge pull request #187 from microsoft/master
merge master
parents
93dd76ba
97829ccd
Changes
57
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
291 additions
and
485 deletions
+291
-485
src/nni_manager/types/tail-stream/index.d.ts
src/nni_manager/types/tail-stream/index.d.ts
+2
-1
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+4
-6
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
+2
-2
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+8
-0
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
+2
-2
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
+2
-2
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+2
-2
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+2
-2
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+23
-7
src/sdk/pynni/nni/multi_phase/__init__.py
src/sdk/pynni/nni/multi_phase/__init__.py
+0
-0
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+0
-198
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
+0
-106
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
.../pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
+2
-2
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
+3
-3
src/sdk/pynni/nni/tuner.py
src/sdk/pynni/nni/tuner.py
+6
-6
src/sdk/pynni/tests/test_multi_phase_tuner.py
src/sdk/pynni/tests/test_multi_phase_tuner.py
+0
-110
src/sdk/pynni/tests/test_tuner.py
src/sdk/pynni/tests/test_tuner.py
+5
-7
src/webui/src/components/Modal/Compare.tsx
src/webui/src/components/Modal/Compare.tsx
+204
-0
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+23
-21
src/webui/src/components/trial-detail/Intermeidate.tsx
src/webui/src/components/trial-detail/Intermeidate.tsx
+1
-8
No files found.
src/nni_manager/types/tail-stream/index.d.ts
View file @
1500458a
declare
module
'
tail-stream
'
{
declare
module
'
tail-stream
'
{
export
interface
Stream
{
export
interface
Stream
{
on
(
type
:
'
data
'
,
callback
:
(
data
:
Buffer
)
=>
void
):
void
;
on
(
type
:
'
data
'
,
callback
:
(
data
:
Buffer
)
=>
void
):
void
;
destroy
():
void
;
end
(
data
:
number
):
void
;
emit
(
data
:
string
):
void
;
}
}
export
function
createReadStream
(
path
:
string
):
Stream
;
export
function
createReadStream
(
path
:
string
):
Stream
;
}
}
\ No newline at end of file
src/sdk/pynni/nni/__main__.py
View file @
1500458a
...
@@ -28,9 +28,8 @@ import json
...
@@ -28,9 +28,8 @@ import json
import
importlib
import
importlib
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
,
AdvisorModuleName
,
AdvisorClassName
from
.constants
import
ModuleName
,
ClassName
,
ClassArgs
,
AdvisorModuleName
,
AdvisorClassName
from
nni.common
import
enable_multi_thread
from
nni.common
import
enable_multi_thread
,
enable_multi_phase
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.msg_dispatcher
import
MsgDispatcher
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
.
debug
(
'START'
)
logger
.
debug
(
'START'
)
...
@@ -126,6 +125,8 @@ def main():
...
@@ -126,6 +125,8 @@ def main():
args
=
parse_args
()
args
=
parse_args
()
if
args
.
multi_thread
:
if
args
.
multi_thread
:
enable_multi_thread
()
enable_multi_thread
()
if
args
.
multi_phase
:
enable_multi_phase
()
if
args
.
advisor_class_name
:
if
args
.
advisor_class_name
:
# advisor is enabled and starts to run
# advisor is enabled and starts to run
...
@@ -180,10 +181,7 @@ def main():
...
@@ -180,10 +181,7 @@ def main():
if
assessor
is
None
:
if
assessor
is
None
:
raise
AssertionError
(
'Failed to create Assessor instance'
)
raise
AssertionError
(
'Failed to create Assessor instance'
)
if
args
.
multi_phase
:
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
,
assessor
)
else
:
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
try
:
try
:
dispatcher
.
run
()
dispatcher
.
run
()
...
...
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
View file @
1500458a
...
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
...
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
"""
self
.
values
=
self
.
is_valid
(
search_space
)
self
.
values
=
self
.
is_valid
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
Parameters
...
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
...
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
return
self
.
values
[
self
.
count
]
return
self
.
values
[
self
.
count
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
pass
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/common.py
View file @
1500458a
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
sys
.
stdout
=
_LoggerFileWrapper
(
logger_file
)
_multi_thread
=
False
_multi_thread
=
False
_multi_phase
=
False
def
enable_multi_thread
():
def
enable_multi_thread
():
global
_multi_thread
global
_multi_thread
...
@@ -76,3 +77,10 @@ def enable_multi_thread():
...
@@ -76,3 +77,10 @@ def enable_multi_thread():
def
multi_thread_enabled
():
def
multi_thread_enabled
():
return
_multi_thread
return
_multi_thread
def
enable_multi_phase
():
global
_multi_phase
_multi_phase
=
True
def
multi_phase_enabled
():
return
_multi_phase
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
View file @
1500458a
...
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
...
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
searchspace_json
,
is_rand
,
self
.
random_state
)
self
.
population
.
append
(
Individual
(
config
=
config
))
self
.
population
.
append
(
Individual
(
config
=
config
))
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
Parameters
...
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
...
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config
=
split_index
(
total_config
)
config
=
split_index
(
total_config
)
return
config
return
config
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
'''Record the result from a trial
'''Record the result from a trial
Parameters
Parameters
...
...
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
View file @
1500458a
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
'''
self
.
expanded_search_space
=
self
.
json2parameter
(
search_space
)
self
.
expanded_search_space
=
self
.
json2parameter
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
self
.
count
+=
1
self
.
count
+=
1
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
...
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
...
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return
self
.
expanded_search_space
[
self
.
count
]
return
self
.
expanded_search_space
[
self
.
count
]
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
pass
pass
def
import_data
(
self
,
data
):
def
import_data
(
self
,
data
):
...
...
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
1500458a
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose
=
0
)
verbose
=
0
)
self
.
rval
.
catch_eval_exceptions
=
False
self
.
rval
.
catch_eval_exceptions
=
False
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
Returns a set of trial (hyper-)parameters, as a serializable object.
...
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
...
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params
=
split_index
(
total_params
)
params
=
split_index
(
total_params
)
return
params
return
params
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""
"""
Record an observation of the objective function
Record an observation of the objective function
...
...
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
View file @
1500458a
...
@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
...
@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
return
output
return
output
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Generate next parameter for trial
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
If the number of trial result is lower than cold start number,
metis will first random generate some parameters.
metis will first random generate some parameters.
...
@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
...
@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
return
results
return
results
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Tuner receive result from trial.
"""Tuner receive result from trial.
Parameters
Parameters
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
1500458a
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# ==================================================================================================
import
os
import
logging
import
logging
from
collections
import
defaultdict
from
collections
import
defaultdict
import
json_tricks
import
json_tricks
...
@@ -26,7 +25,7 @@ import json_tricks
...
@@ -26,7 +25,7 @@ import json_tricks
from
.protocol
import
CommandType
,
send
from
.protocol
import
CommandType
,
send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.assessor
import
AssessResult
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
from
.env_vars
import
dispatcher_env_vars
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -61,13 +60,19 @@ def _create_parameter_id():
...
@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id
+=
1
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
):
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
_trial_params
[
parameter_id
]
=
params
ret
=
{
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
'parameters'
:
params
}
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
return
json_tricks
.
dumps
(
ret
)
class
MsgDispatcher
(
MsgDispatcherBase
):
class
MsgDispatcher
(
MsgDispatcherBase
):
...
@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif
data
[
'type'
]
==
'PERIODICAL'
:
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
self
.
_handle_intermediate_metric_data
(
data
)
else
:
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
pass
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
=
data
[
'trial_job_id'
])
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
...
@@ -160,9 +170,15 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -160,9 +170,15 @@ class MsgDispatcher(MsgDispatcherBase):
id_
=
data
[
'parameter_id'
]
id_
=
data
[
'parameter_id'
]
value
=
data
[
'value'
]
value
=
data
[
'value'
]
if
id_
in
_customized_parameter_ids
:
if
id_
in
_customized_parameter_ids
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
if
multi_phase_enabled
():
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
else
:
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
if
multi_phase_enabled
():
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
def
_handle_intermediate_metric_data
(
self
,
data
):
def
_handle_intermediate_metric_data
(
self
,
data
):
"""Call assessor to process intermediate results
"""Call assessor to process intermediate results
...
...
src/sdk/pynni/nni/multi_phase/__init__.py
deleted
100644 → 0
View file @
93dd76ba
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
deleted
100644 → 0
View file @
93dd76ba
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
collections
import
defaultdict
import
json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.assessor
import
AssessResult
_logger
=
logging
.
getLogger
(
__name__
)
# Assessor global variables
_trial_history
=
defaultdict
(
dict
)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials
=
set
()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def
_sort_history
(
history
):
ret
=
[
]
for
i
,
_
in
enumerate
(
history
):
if
i
in
history
:
ret
.
append
(
history
[
i
])
else
:
break
return
ret
# Tuner global variables
_next_parameter_id
=
0
_trial_params
=
{}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids
=
set
()
def
_create_parameter_id
():
global
_next_parameter_id
# pylint: disable=global-statement
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
class
MultiPhaseMsgDispatcher
(
MsgDispatcherBase
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
super
(
MultiPhaseMsgDispatcher
,
self
).
__init__
()
self
.
tuner
=
tuner
self
.
assessor
=
assessor
if
assessor
is
None
:
_logger
.
debug
(
'Assessor is not configured'
)
def
load_checkpoint
(
self
):
self
.
tuner
.
load_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
load_checkpoint
()
def
save_checkpoint
(
self
):
self
.
tuner
.
save_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
save_checkpoint
()
def
handle_initialize
(
self
,
data
):
'''
data is search space
'''
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
params_list
=
self
.
tuner
.
generate_multiple_parameters
(
ids
)
assert
len
(
ids
)
==
len
(
params_list
)
for
i
,
_
in
enumerate
(
ids
):
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
ids
[
i
],
params_list
[
i
]))
return
True
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
return
True
def
handle_import_data
(
self
,
data
):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
return
True
def
handle_report_metric_data
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
if
data
[
'type'
]
==
'FINAL'
:
id_
=
data
[
'parameter_id'
]
if
id_
in
_customized_parameter_ids
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
else
:
pass
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
)
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
return
True
def
handle_trial_end
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
if
trial_job_id
in
_trial_history
:
_trial_history
.
pop
(
trial_job_id
)
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
,
trial_job_id
)
return
True
def
handle_import_data
(
self
,
data
):
pass
def
_handle_intermediate_metric_data
(
self
,
data
):
if
data
[
'type'
]
!=
'PERIODICAL'
:
return
True
if
self
.
assessor
is
None
:
return
True
trial_job_id
=
data
[
'trial_job_id'
]
if
trial_job_id
in
_ended_trials
:
return
True
history
=
_trial_history
[
trial_job_id
]
history
[
data
[
'sequence'
]]
=
data
[
'value'
]
ordered_history
=
_sort_history
(
history
)
if
len
(
ordered_history
)
<
data
[
'sequence'
]:
# no user-visible update since last time
return
True
try
:
result
=
self
.
assessor
.
assess_trial
(
trial_job_id
,
ordered_history
)
except
Exception
as
e
:
_logger
.
exception
(
'Assessor error'
)
if
isinstance
(
result
,
bool
):
result
=
AssessResult
.
Good
if
result
else
AssessResult
.
Bad
elif
not
isinstance
(
result
,
AssessResult
):
msg
=
'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise
RuntimeError
(
msg
%
type
(
result
))
if
result
is
AssessResult
.
Bad
:
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dumps
(
trial_job_id
))
else
:
_logger
.
debug
(
'GOOD'
)
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
deleted
100644 → 0
View file @
93dd76ba
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
nni.recoverable
import
Recoverable
_logger
=
logging
.
getLogger
(
__name__
)
class
MultiPhaseTuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return
[
self
.
generate_parameters
(
parameter_id
)
for
parameter_id
in
parameter_id_list
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
,
trial_job_id
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def
update_search_space
(
self
,
search_space
):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
load_checkpoint
(
self
):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Load checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
save_checkpoint
(
self
):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
_on_exit
(
self
):
pass
def
_on_error
(
self
):
pass
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
View file @
1500458a
...
@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
"""
"""
self
.
search_space
=
search_space
self
.
search_space
=
search_space
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
"""
Returns a set of trial neural architecture, as a serializable object.
Returns a set of trial neural architecture, as a serializable object.
...
@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
...
@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
return
json_out
return
json_out
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
""" Record an observation of the objective function.
""" Record an observation of the objective function.
Parameters
Parameters
...
...
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
View file @
1500458a
...
@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
...
@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
else
:
else
:
self
.
logger
.
warning
(
'update search space is not supported.'
)
self
.
logger
.
warning
(
'update search space is not supported.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""receive_trial_result
"""receive_trial_result
Parameters
Parameters
...
@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
...
@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
converted_dict
[
key
]
=
value
converted_dict
[
key
]
=
value
return
converted_dict
return
converted_dict
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""generate one instance of hyperparameters
"""generate one instance of hyperparameters
Parameters
Parameters
...
@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
...
@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
self
.
total_data
[
parameter_id
]
=
challenger
self
.
total_data
[
parameter_id
]
=
challenger
return
self
.
convert_loguniform_categorical
(
challenger
.
get_dictionary
())
return
self
.
convert_loguniform_categorical
(
challenger
.
get_dictionary
())
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""generate mutiple instances of hyperparameters
"""generate mutiple instances of hyperparameters
Parameters
Parameters
...
...
src/sdk/pynni/nni/tuner.py
View file @
1500458a
...
@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
...
@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
class
Tuner
(
Recoverable
):
class
Tuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
parameter_id: int
"""
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
User code must override either this function or 'generate_parameters()'.
...
@@ -49,13 +49,13 @@ class Tuner(Recoverable):
...
@@ -49,13 +49,13 @@ class Tuner(Recoverable):
for
parameter_id
in
parameter_id_list
:
for
parameter_id
in
parameter_id_list
:
try
:
try
:
_logger
.
debug
(
"generating param for {}"
.
format
(
parameter_id
))
_logger
.
debug
(
"generating param for {}"
.
format
(
parameter_id
))
res
=
self
.
generate_parameters
(
parameter_id
)
res
=
self
.
generate_parameters
(
parameter_id
,
**
kwargs
)
except
nni
.
NoMoreTrialError
:
except
nni
.
NoMoreTrialError
:
return
result
return
result
result
.
append
(
res
)
result
.
append
(
res
)
return
result
return
result
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial reports its final result. Must override.
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameter_id: int
parameters: object created by 'generate_parameters()'
parameters: object created by 'generate_parameters()'
...
@@ -63,7 +63,7 @@ class Tuner(Recoverable):
...
@@ -63,7 +63,7 @@ class Tuner(Recoverable):
"""
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameter_id: int
parameters: object created by user
parameters: object created by user
...
@@ -71,7 +71,7 @@ class Tuner(Recoverable):
...
@@ -71,7 +71,7 @@ class Tuner(Recoverable):
"""
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
):
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
success: True if the trial successfully completed; False if failed or terminated
...
...
src/sdk/pynni/tests/test_multi_phase_tuner.py
deleted
100644 → 0
View file @
93dd76ba
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
import
random
from
io
import
BytesIO
import
nni
import
nni.protocol
from
nni.protocol
import
CommandType
,
send
,
receive
from
nni.multi_phase.multi_phase_tuner
import
MultiPhaseTuner
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
from
unittest
import
TestCase
,
main
class
NaiveMultiPhaseTuner
(
MultiPhaseTuner
):
'''
supports only choices
'''
def
__init__
(
self
):
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters
=
{}
if
self
.
search_space
is
None
:
raise
AssertionError
(
'Search space not specified'
)
for
k
in
self
.
search_space
:
param
=
self
.
search_space
[
k
]
if
not
param
[
'_type'
]
==
'choice'
:
raise
ValueError
(
'Only choice type is supported'
)
param_values
=
param
[
'_value'
]
generated_parameters
[
k
]
=
param_values
[
random
.
randint
(
0
,
len
(
param_values
)
-
1
)]
logging
.
getLogger
(
__name__
).
debug
(
generated_parameters
)
return
generated_parameters
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
logging
.
getLogger
(
__name__
).
debug
(
'receive_trial_result: {},{},{},{}'
.
format
(
parameter_id
,
parameters
,
value
,
trial_job_id
))
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
pass
def
update_search_space
(
self
,
search_space
):
self
.
search_space
=
search_space
_in_buf
=
BytesIO
()
_out_buf
=
BytesIO
()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_out_file
=
_in_buf
nni
.
protocol
.
_in_file
=
_out_buf
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_in_file
=
_in_buf
nni
.
protocol
.
_out_file
=
_out_buf
def
_test_tuner
():
_reverse_io
()
# now we are sending to Tuner's incoming stream
send
(
CommandType
.
UpdateSearchSpace
,
"{
\"
learning_rate
\"
: {
\"
_value
\"
: [0.0001, 0.001, 0.002, 0.005, 0.01],
\"
_type
\"
:
\"
choice
\"
},
\"
optimizer
\"
: {
\"
_value
\"
: [
\"
Adam
\"
,
\"
SGD
\"
],
\"
_type
\"
:
\"
choice
\"
}}"
)
send
(
CommandType
.
RequestTrialJobs
,
'2'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}'
)
send
(
CommandType
.
AddCustomizedTrialJob
,
'{"param":-1}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}'
)
send
(
CommandType
.
RequestTrialJobs
,
'1'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"abc"}'
)
_restore_io
()
tuner
=
NaiveMultiPhaseTuner
()
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
)
dispatcher
.
run
()
_reverse_io
()
# now we are receiving from Tuner's outgoing stream
command
,
data
=
receive
()
# this one is customized
print
(
command
,
data
)
class
MultiPhaseTestCase
(
TestCase
):
def
test_tuner
(
self
):
_test_tuner
()
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
src/sdk/pynni/tests/test_tuner.py
View file @
1500458a
...
@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
...
@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
self
.
trial_results
=
[
]
self
.
trial_results
=
[
]
self
.
search_space
=
None
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
# report Tuner's internal states to generated parameters,
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
# so we don't need to pause the main loop
self
.
param
+=
2
self
.
param
+=
2
...
@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
...
@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
'search_space'
:
self
.
search_space
'search_space'
:
self
.
search_space
}
}
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
reward
=
extract_scalar_reward
(
value
)
reward
=
extract_scalar_reward
(
value
)
self
.
trial_results
.
append
((
parameter_id
,
parameters
[
'param'
],
reward
,
False
))
self
.
trial_results
.
append
((
parameter_id
,
parameters
[
'param'
],
reward
,
False
))
...
@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
...
@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
command
,
data
=
receive
()
# this one is customized
command
,
data
=
receive
()
# this one is customized
data
=
json
.
loads
(
data
)
data
=
json
.
loads
(
data
)
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
self
.
assertEqual
(
data
,
{
self
.
assertEqual
(
data
[
'parameter_id'
],
2
)
'parameter_id'
:
2
,
self
.
assertEqual
(
data
[
'parameter_source'
],
'customized'
)
'parameter_source'
:
'customized'
,
self
.
assertEqual
(
data
[
'parameters'
],
{
'param'
:
-
1
})
'parameters'
:
{
'param'
:
-
1
}
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
...
...
src/webui/src/components/Modal/Compare.tsx
0 → 100644
View file @
1500458a
import
*
as
React
from
'
react
'
;
import
{
Row
,
Modal
}
from
'
antd
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
import
'
../../static/style/compare.scss
'
;
import
{
TableObj
,
Intermedia
,
TooltipForIntermediate
}
from
'
src/static/interface
'
;
// the modal of trial compare
interface
CompareProps
{
compareRows
:
Array
<
TableObj
>
;
visible
:
boolean
;
cancelFunc
:
()
=>
void
;
}
class
Compare
extends
React
.
Component
<
CompareProps
,
{}
>
{
public
_isCompareMount
:
boolean
;
constructor
(
props
:
CompareProps
)
{
super
(
props
);
}
intermediate
=
()
=>
{
const
{
compareRows
}
=
this
.
props
;
const
trialIntermediate
:
Array
<
Intermedia
>
=
[];
const
idsList
:
Array
<
string
>
=
[];
Object
.
keys
(
compareRows
).
map
(
item
=>
{
const
temp
=
compareRows
[
item
];
trialIntermediate
.
push
({
name
:
temp
.
id
,
data
:
temp
.
description
.
intermediate
,
type
:
'
line
'
,
hyperPara
:
temp
.
description
.
parameters
});
idsList
.
push
(
temp
.
id
);
});
// find max intermediate number
trialIntermediate
.
sort
((
a
,
b
)
=>
{
return
(
b
.
data
.
length
-
a
.
data
.
length
);
});
const
legend
:
Array
<
string
>
=
[];
// max length
const
length
=
trialIntermediate
[
0
]
!==
undefined
?
trialIntermediate
[
0
].
data
.
length
:
0
;
const
xAxis
:
Array
<
number
>
=
[];
Object
.
keys
(
trialIntermediate
).
map
(
item
=>
{
const
temp
=
trialIntermediate
[
item
];
legend
.
push
(
temp
.
name
);
});
for
(
let
i
=
1
;
i
<=
length
;
i
++
)
{
xAxis
.
push
(
i
);
}
const
option
=
{
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForIntermediate
)
{
if
(
data
.
dataIndex
<
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
formatter
:
function
(
data
:
TooltipForIntermediate
)
{
const
trialId
=
data
.
seriesName
;
let
obj
=
{};
const
temp
=
trialIntermediate
.
find
(
key
=>
key
.
name
===
trialId
);
if
(
temp
!==
undefined
)
{
obj
=
temp
.
hyperPara
;
}
return
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial ID:
'
+
trialId
+
'
</div>
'
+
'
<div>Intermediate:
'
+
data
.
data
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
obj
,
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
}
},
grid
:
{
left
:
'
5%
'
,
top
:
40
,
containLabel
:
true
},
legend
:
{
data
:
idsList
},
xAxis
:
{
type
:
'
category
'
,
name
:
'
Step
'
,
boundaryGap
:
false
,
data
:
xAxis
},
yAxis
:
{
type
:
'
value
'
,
name
:
'
metric
'
},
series
:
trialIntermediate
};
return
(
<
ReactEcharts
option
=
{
option
}
style
=
{
{
width
:
'
100%
'
,
height
:
418
,
margin
:
'
0 auto
'
}
}
notMerge
=
{
true
}
// update now
/>
);
}
// render table column ---
initColumn
=
()
=>
{
const
{
compareRows
}
=
this
.
props
;
const
idList
:
Array
<
string
>
=
[];
const
durationList
:
Array
<
number
>
=
[];
const
parameterList
:
Array
<
object
>
=
[];
let
parameterKeys
:
Array
<
string
>
=
[];
if
(
compareRows
.
length
!==
0
)
{
parameterKeys
=
Object
.
keys
(
compareRows
[
0
].
description
.
parameters
);
}
Object
.
keys
(
compareRows
).
map
(
item
=>
{
const
temp
=
compareRows
[
item
];
idList
.
push
(
temp
.
id
);
durationList
.
push
(
temp
.
duration
);
parameterList
.
push
(
temp
.
description
.
parameters
);
});
return
(
<
table
className
=
"compare"
>
<
tbody
>
<
tr
>
<
td
/>
{
Object
.
keys
(
idList
).
map
(
key
=>
{
return
(
<
td
className
=
"value"
key
=
{
key
}
>
{
idList
[
key
]
}
</
td
>
);
})
}
</
tr
>
<
tr
>
<
td
className
=
"column"
>
Default metric
</
td
>
{
Object
.
keys
(
compareRows
).
map
(
index
=>
{
const
temp
=
compareRows
[
index
];
return
(
<
td
className
=
"value"
key
=
{
index
}
>
<
IntermediateVal
record
=
{
temp
}
/>
</
td
>
);
})
}
</
tr
>
<
tr
>
<
td
className
=
"column"
>
duration
</
td
>
{
Object
.
keys
(
durationList
).
map
(
index
=>
{
return
(
<
td
className
=
"value"
key
=
{
index
}
>
{
durationList
[
index
]
}
</
td
>
);
})
}
</
tr
>
{
Object
.
keys
(
parameterKeys
).
map
(
index
=>
{
return
(
<
tr
key
=
{
index
}
>
<
td
className
=
"column"
key
=
{
index
}
>
{
parameterKeys
[
index
]
}
</
td
>
{
Object
.
keys
(
parameterList
).
map
(
key
=>
{
return
(
<
td
key
=
{
key
}
className
=
"value"
>
{
parameterList
[
key
][
parameterKeys
[
index
]]
}
</
td
>
);
})
}
</
tr
>
);
})
}
</
tbody
>
</
table
>
);
}
componentDidMount
()
{
this
.
_isCompareMount
=
true
;
}
componentWillUnmount
()
{
this
.
_isCompareMount
=
false
;
}
render
()
{
const
{
visible
,
cancelFunc
}
=
this
.
props
;
return
(
<
Modal
title
=
"Compare trials"
visible
=
{
visible
}
onCancel
=
{
cancelFunc
}
footer
=
{
null
}
destroyOnClose
=
{
true
}
maskClosable
=
{
false
}
width
=
"90%"
>
<
Row
>
{
this
.
intermediate
()
}
</
Row
>
<
Row
>
{
this
.
initColumn
()
}
</
Row
>
</
Modal
>
);
}
}
export
default
Compare
;
src/webui/src/components/TrialsDetail.tsx
View file @
1500458a
...
@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
{
/* trial table list */
}
{
/* trial table list */
}
<
Title1
text
=
"Trial jobs"
icon
=
"6.png"
/>
<
Title1
text
=
"Trial jobs"
icon
=
"6.png"
/>
<
Row
className
=
"allList"
>
<
Row
className
=
"allList"
>
<
Col
span
=
{
1
2
}
>
<
Col
span
=
{
1
0
}
>
<
span
>
Show
</
span
>
<
span
>
Show
</
span
>
<
Select
<
Select
className
=
"entry"
className
=
"entry"
...
@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
...
@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</
Select
>
</
Select
>
<
span
>
entries
</
span
>
<
span
>
entries
</
span
>
</
Col
>
</
Col
>
<
Col
span
=
{
12
}
className
=
"right"
>
<
Col
span
=
{
14
}
className
=
"right"
>
<
Row
>
<
Button
<
Col
span
=
{
12
}
>
type
=
"primary"
<
Button
className
=
"tableButton editStyle"
type
=
"primary"
onClick
=
{
this
.
tableList
?
this
.
tableList
.
addColumn
:
this
.
test
}
className
=
"tableButton editStyle"
>
onClick
=
{
this
.
tableList
?
this
.
tableList
.
addColumn
:
this
.
test
}
Add column
>
</
Button
>
Add column
<
Button
</
Button
>
type
=
"primary"
</
Col
>
className
=
"tableButton editStyle mediateBtn"
<
Col
span
=
{
12
}
>
// use child-component tableList's function, the function is in child-component.
<
Input
onClick
=
{
this
.
tableList
?
this
.
tableList
.
compareBtn
:
this
.
test
}
type
=
"text"
>
placeholder
=
"Search by id, trial No. or status"
Compare
onChange
=
{
this
.
searchTrial
}
</
Button
>
style
=
{
{
width
:
230
,
marginLeft
:
6
}
}
<
Input
/>
type
=
"text"
</
Col
>
placeholder
=
"Search by id, trial No. or status"
</
Row
>
onChange
=
{
this
.
searchTrial
}
style
=
{
{
width
:
230
,
marginLeft
:
6
}
}
/>
</
Col
>
</
Col
>
</
Row
>
</
Row
>
<
TableList
<
TableList
...
...
src/webui/src/components/trial-detail/Intermeidate.tsx
View file @
1500458a
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
{
Row
,
Col
,
Button
,
Switch
}
from
'
antd
'
;
import
{
Row
,
Col
,
Button
,
Switch
}
from
'
antd
'
;
import
{
TooltipForIntermediate
,
TableObj
}
from
'
../../static/interface
'
;
import
{
TooltipForIntermediate
,
TableObj
,
Intermedia
}
from
'
../../static/interface
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/title
'
);
require
(
'
echarts/lib/component/title
'
);
interface
Intermedia
{
name
:
string
;
// id
type
:
string
;
data
:
Array
<
number
|
object
>
;
// intermediate data
hyperPara
:
object
;
// each trial hyperpara value
}
interface
IntermediateState
{
interface
IntermediateState
{
detailSource
:
Array
<
TableObj
>
;
detailSource
:
Array
<
TableObj
>
;
interSource
:
object
;
interSource
:
object
;
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment