Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6febf29
"src/targets/cpu/erf.cpp" did not exist on "e1ef1e173bb4ce5a7479178c627e49a8f51884ce"
Commit
d6febf29
authored
Jun 25, 2019
by
suiguoxin
Browse files
Merge branch 'master' of
git://github.com/microsoft/nni
parents
77c95479
c2179921
Changes
90
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
658 additions
and
531 deletions
+658
-531
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+42
-26
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+23
-7
src/sdk/pynni/nni/multi_phase/__init__.py
src/sdk/pynni/nni/multi_phase/__init__.py
+0
-0
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+0
-198
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
+0
-106
src/sdk/pynni/nni/nas_utils.py
src/sdk/pynni/nni/nas_utils.py
+160
-0
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
.../pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
+2
-2
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
+3
-3
src/sdk/pynni/nni/smartparam.py
src/sdk/pynni/nni/smartparam.py
+35
-8
src/sdk/pynni/nni/tuner.py
src/sdk/pynni/nni/tuner.py
+6
-6
src/sdk/pynni/tests/test_multi_phase_tuner.py
src/sdk/pynni/tests/test_multi_phase_tuner.py
+0
-110
src/sdk/pynni/tests/test_smartparam.py
src/sdk/pynni/tests/test_smartparam.py
+16
-1
src/sdk/pynni/tests/test_tuner.py
src/sdk/pynni/tests/test_tuner.py
+5
-7
src/webui/src/components/Modal/Compare.tsx
src/webui/src/components/Modal/Compare.tsx
+204
-0
src/webui/src/components/Overview.tsx
src/webui/src/components/Overview.tsx
+4
-2
src/webui/src/components/SlideBar.tsx
src/webui/src/components/SlideBar.tsx
+5
-18
src/webui/src/components/TrialsDetail.tsx
src/webui/src/components/TrialsDetail.tsx
+23
-21
src/webui/src/components/trial-detail/Intermeidate.tsx
src/webui/src/components/trial-detail/Intermeidate.tsx
+1
-8
src/webui/src/components/trial-detail/Para.tsx
src/webui/src/components/trial-detail/Para.tsx
+11
-3
src/webui/src/components/trial-detail/TableList.tsx
src/webui/src/components/trial-detail/TableList.tsx
+118
-5
No files found.
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
View file @
d6febf29
...
...
@@ -20,15 +20,15 @@
import
copy
import
logging
import
numpy
as
np
import
os
import
random
import
statistics
import
sys
import
warnings
from
enum
import
Enum
,
unique
from
multiprocessing.dummy
import
Pool
as
ThreadPool
import
numpy
as
np
import
nni.metis_tuner.lib_constraint_summation
as
lib_constraint_summation
import
nni.metis_tuner.lib_data
as
lib_data
import
nni.metis_tuner.Regression_GMM.CreateModel
as
gmm_create_model
...
...
@@ -42,8 +42,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward
logger
=
logging
.
getLogger
(
"Metis_Tuner_AutoML"
)
NONE_TYPE
=
''
CONSTRAINT_LOWERBOUND
=
None
CONSTRAINT_UPPERBOUND
=
None
...
...
@@ -93,7 +91,7 @@ class MetisTuner(Tuner):
self
.
space
=
None
self
.
no_resampling
=
no_resampling
self
.
no_candidates
=
no_candidates
self
.
optimize_mode
=
optimize_mode
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
key_order
=
[]
self
.
cold_start_num
=
cold_start_num
self
.
selection_num_starting_points
=
selection_num_starting_points
...
...
@@ -174,7 +172,7 @@ class MetisTuner(Tuner):
return
output
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
metis will first random generate some parameters.
...
...
@@ -205,7 +203,7 @@ class MetisTuner(Tuner):
return
results
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Tuner receive result from trial.
Parameters
...
...
@@ -254,6 +252,9 @@ class MetisTuner(Tuner):
threshold_samplessize_resampling
=
50
,
no_candidates
=
False
,
minimize_starting_points
=
None
,
minimize_constraints_fun
=
None
):
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
)
next_candidate
=
None
candidates
=
[]
samples_size_all
=
sum
([
len
(
i
)
for
i
in
samples_y
])
...
...
@@ -271,13 +272,12 @@ class MetisTuner(Tuner):
minimize_constraints_fun
=
minimize_constraints_fun
)
if
not
lm_current
:
return
None
if
no_candidates
is
False
:
candidates
.
append
({
'hyperparameter'
:
lm_current
[
'hyperparameter'
],
logger
.
info
({
'hyperparameter'
:
lm_current
[
'hyperparameter'
],
'expected_mu'
:
lm_current
[
'expected_mu'
],
'expected_sigma'
:
lm_current
[
'expected_sigma'
],
'reason'
:
"exploitation_gp"
})
if
no_candidates
is
False
:
# ===== STEP 2: Get recommended configurations for exploration =====
results_exploration
=
gp_selection
.
selection
(
"lc"
,
...
...
@@ -290,34 +290,48 @@ class MetisTuner(Tuner):
if
results_exploration
is
not
None
:
if
_num_past_samples
(
results_exploration
[
'hyperparameter'
],
samples_x
,
samples_y
)
==
0
:
candidate
s
.
append
(
{
'hyperparameter'
:
results_exploration
[
'hyperparameter'
],
temp_
candidate
=
{
'hyperparameter'
:
results_exploration
[
'hyperparameter'
],
'expected_mu'
:
results_exploration
[
'expected_mu'
],
'expected_sigma'
:
results_exploration
[
'expected_sigma'
],
'reason'
:
"exploration"
})
'reason'
:
"exploration"
}
candidates
.
append
(
temp_candidate
)
logger
.
info
(
"DEBUG: 1 exploration candidate selected
\n
"
)
logger
.
info
(
temp_candidate
)
else
:
logger
.
info
(
"DEBUG: No suitable exploration candidates were"
)
# ===== STEP 3: Get recommended configurations for exploitation =====
if
samples_size_all
>=
threshold_samplessize_exploitation
:
print
(
"Getting candidates for exploitation...
\n
"
)
logger
.
info
(
"Getting candidates for exploitation...
\n
"
)
try
:
gmm
=
gmm_create_model
.
create_model
(
samples_x
,
samples_y_aggregation
)
results_exploitation
=
gmm_selection
.
selection
(
x_bounds
,
x_types
,
if
(
"discrete_int"
in
x_types
)
or
(
"range_int"
in
x_types
):
results_exploitation
=
gmm_selection
.
selection
(
x_bounds
,
x_types
,
gmm
[
'clusteringmodel_good'
],
gmm
[
'clusteringmodel_bad'
],
minimize_starting_points
,
minimize_constraints_fun
=
minimize_constraints_fun
)
else
:
# If all parameters are of "range_continuous", let's use GMM to generate random starting points
results_exploitation
=
gmm_selection
.
selection_r
(
x_bounds
,
x_types
,
gmm
[
'clusteringmodel_good'
],
gmm
[
'clusteringmodel_bad'
],
num_starting_points
=
self
.
selection_num_starting_points
,
minimize_constraints_fun
=
minimize_constraints_fun
)
if
results_exploitation
is
not
None
:
if
_num_past_samples
(
results_exploitation
[
'hyperparameter'
],
samples_x
,
samples_y
)
==
0
:
candidates
.
append
({
'hyperparameter'
:
results_exploitation
[
'hyperparameter'
],
\
'expected_mu'
:
results_exploitation
[
'expected_mu'
],
\
'expected_sigma'
:
results_exploitation
[
'expected_sigma'
],
\
'reason'
:
"exploitation_gmm"
})
temp_expected_mu
,
temp_expected_sigma
=
gp_prediction
.
predict
(
results_exploitation
[
'hyperparameter'
],
gp_model
[
'model'
])
temp_candidate
=
{
'hyperparameter'
:
results_exploitation
[
'hyperparameter'
],
'expected_mu'
:
temp_expected_mu
,
'expected_sigma'
:
temp_expected_sigma
,
'reason'
:
"exploitation_gmm"
}
candidates
.
append
(
temp_candidate
)
logger
.
info
(
"DEBUG: 1 exploitation_gmm candidate selected
\n
"
)
logger
.
info
(
temp_candidate
)
else
:
logger
.
info
(
"DEBUG: No suitable exploitation_gmm candidates were found
\n
"
)
...
...
@@ -338,11 +352,13 @@ class MetisTuner(Tuner):
if
results_outliers
is
not
None
:
for
results_outlier
in
results_outliers
:
if
_num_past_samples
(
samples_x
[
results_outlier
[
'samples_idx'
]],
samples_x
,
samples_y
)
<
max_resampling_per_x
:
candidate
s
.
append
(
{
'hyperparameter'
:
samples_x
[
results_outlier
[
'samples_idx'
]],
\
temp_
candidate
=
{
'hyperparameter'
:
samples_x
[
results_outlier
[
'samples_idx'
]],
\
'expected_mu'
:
results_outlier
[
'expected_mu'
],
\
'expected_sigma'
:
results_outlier
[
'expected_sigma'
],
\
'reason'
:
"resampling"
})
'reason'
:
"resampling"
}
candidates
.
append
(
temp_candidate
)
logger
.
info
(
"DEBUG: %d re-sampling candidates selected
\n
"
)
logger
.
info
(
temp_candidate
)
else
:
logger
.
info
(
"DEBUG: No suitable resampling candidates were found
\n
"
)
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
d6febf29
...
...
@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
import
logging
from
collections
import
defaultdict
import
json_tricks
...
...
@@ -26,7 +25,7 @@ import json_tricks
from
.protocol
import
CommandType
,
send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
):
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
class
MsgDispatcher
(
MsgDispatcherBase
):
...
...
@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
else
:
pass
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
=
data
[
'trial_job_id'
])
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
...
...
@@ -160,7 +170,13 @@ class MsgDispatcher(MsgDispatcherBase):
id_
=
data
[
'parameter_id'
]
value
=
data
[
'value'
]
if
id_
in
_customized_parameter_ids
:
if
multi_phase_enabled
():
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
else
:
if
multi_phase_enabled
():
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
trial_job_id
=
data
[
'trial_job_id'
])
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
)
...
...
src/sdk/pynni/nni/multi_phase/__init__.py
deleted
100644 → 0
View file @
77c95479
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
deleted
100644 → 0
View file @
77c95479
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
collections
import
defaultdict
import
json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.assessor
import
AssessResult
_logger
=
logging
.
getLogger
(
__name__
)
# Assessor global variables
_trial_history
=
defaultdict
(
dict
)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials
=
set
()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def
_sort_history
(
history
):
ret
=
[
]
for
i
,
_
in
enumerate
(
history
):
if
i
in
history
:
ret
.
append
(
history
[
i
])
else
:
break
return
ret
# Tuner global variables
_next_parameter_id
=
0
_trial_params
=
{}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids
=
set
()
def
_create_parameter_id
():
global
_next_parameter_id
# pylint: disable=global-statement
_next_parameter_id
+=
1
return
_next_parameter_id
-
1
def
_pack_parameter
(
parameter_id
,
params
,
customized
=
False
,
trial_job_id
=
None
,
parameter_index
=
None
):
_trial_params
[
parameter_id
]
=
params
ret
=
{
'parameter_id'
:
parameter_id
,
'parameter_source'
:
'customized'
if
customized
else
'algorithm'
,
'parameters'
:
params
}
if
trial_job_id
is
not
None
:
ret
[
'trial_job_id'
]
=
trial_job_id
if
parameter_index
is
not
None
:
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
json_tricks
.
dumps
(
ret
)
class
MultiPhaseMsgDispatcher
(
MsgDispatcherBase
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
super
(
MultiPhaseMsgDispatcher
,
self
).
__init__
()
self
.
tuner
=
tuner
self
.
assessor
=
assessor
if
assessor
is
None
:
_logger
.
debug
(
'Assessor is not configured'
)
def
load_checkpoint
(
self
):
self
.
tuner
.
load_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
load_checkpoint
()
def
save_checkpoint
(
self
):
self
.
tuner
.
save_checkpoint
()
if
self
.
assessor
is
not
None
:
self
.
assessor
.
save_checkpoint
()
def
handle_initialize
(
self
,
data
):
'''
data is search space
'''
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
return
True
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
ids
=
[
_create_parameter_id
()
for
_
in
range
(
data
)]
params_list
=
self
.
tuner
.
generate_multiple_parameters
(
ids
)
assert
len
(
ids
)
==
len
(
params_list
)
for
i
,
_
in
enumerate
(
ids
):
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
ids
[
i
],
params_list
[
i
]))
return
True
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
return
True
def
handle_import_data
(
self
,
data
):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
return
True
def
handle_report_metric_data
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
if
data
[
'type'
]
==
'FINAL'
:
id_
=
data
[
'parameter_id'
]
if
id_
in
_customized_parameter_ids
:
self
.
tuner
.
receive_customized_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
else
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
data
[
'value'
],
trial_job_id
)
elif
data
[
'type'
]
==
'PERIODICAL'
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
else
:
pass
elif
data
[
'type'
]
==
'REQUEST_PARAMETER'
:
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
param_id
=
_create_parameter_id
()
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
)
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
return
True
def
handle_trial_end
(
self
,
data
):
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
if
trial_job_id
in
_trial_history
:
_trial_history
.
pop
(
trial_job_id
)
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
,
trial_job_id
)
return
True
def
handle_import_data
(
self
,
data
):
pass
def
_handle_intermediate_metric_data
(
self
,
data
):
if
data
[
'type'
]
!=
'PERIODICAL'
:
return
True
if
self
.
assessor
is
None
:
return
True
trial_job_id
=
data
[
'trial_job_id'
]
if
trial_job_id
in
_ended_trials
:
return
True
history
=
_trial_history
[
trial_job_id
]
history
[
data
[
'sequence'
]]
=
data
[
'value'
]
ordered_history
=
_sort_history
(
history
)
if
len
(
ordered_history
)
<
data
[
'sequence'
]:
# no user-visible update since last time
return
True
try
:
result
=
self
.
assessor
.
assess_trial
(
trial_job_id
,
ordered_history
)
except
Exception
as
e
:
_logger
.
exception
(
'Assessor error'
)
if
isinstance
(
result
,
bool
):
result
=
AssessResult
.
Good
if
result
else
AssessResult
.
Bad
elif
not
isinstance
(
result
,
AssessResult
):
msg
=
'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise
RuntimeError
(
msg
%
type
(
result
))
if
result
is
AssessResult
.
Bad
:
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dumps
(
trial_job_id
))
else
:
_logger
.
debug
(
'GOOD'
)
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
deleted
100644 → 0
View file @
77c95479
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
from
nni.recoverable
import
Recoverable
_logger
=
logging
.
getLogger
(
__name__
)
class
MultiPhaseTuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return
[
self
.
generate_parameters
(
parameter_id
)
for
parameter_id
in
parameter_id_list
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
,
trial_job_id
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def
update_search_space
(
self
,
search_space
):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
load_checkpoint
(
self
):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Load checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
save_checkpoint
(
self
):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
%
checkpoin_path
)
def
_on_exit
(
self
):
pass
def
_on_error
(
self
):
pass
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/nas_utils.py
0 → 100644
View file @
d6febf29
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
from
.
import
trial
def
classic_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
):
'''Execute the chosen function and inputs directly.
In this mode, the trial code is only running the chosen subgraph (i.e., the chosen ops and inputs),
without touching the full model graph.'''
if
trial
.
_params
is
None
:
trial
.
get_next_parameter
()
mutable_block
=
trial
.
get_current_parameter
(
mutable_id
)
chosen_layer
=
mutable_block
[
mutable_layer_id
][
"chosen_layer"
]
chosen_inputs
=
mutable_block
[
mutable_layer_id
][
"chosen_inputs"
]
real_chosen_inputs
=
[
optional_inputs
[
input_name
]
for
input_name
in
chosen_inputs
]
layer_out
=
funcs
[
chosen_layer
](
[
fixed_inputs
,
real_chosen_inputs
],
**
funcs_args
[
chosen_layer
])
return
layer_out
def
enas_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
,
tf
):
'''For enas mode, we build the full model graph in trial but only run a subgraph。
This is implemented by masking inputs and branching ops.
Specifically, based on the received subgraph (through nni.get_next_parameter),
it can be known which inputs should be masked and which op should be executed.'''
name_prefix
=
"{}_{}"
.
format
(
mutable_id
,
mutable_layer_id
)
# store namespace
if
'name_space'
not
in
globals
():
global
name_space
name_space
=
dict
()
name_space
[
mutable_id
]
=
True
name_space
[
name_prefix
]
=
dict
()
name_space
[
name_prefix
][
'funcs'
]
=
list
(
funcs
)
name_space
[
name_prefix
][
'optional_inputs'
]
=
list
(
optional_inputs
)
# create tensorflow variables as 1/0 signals used to form subgraph
if
'tf_variables'
not
in
globals
():
global
tf_variables
tf_variables
=
dict
()
name_for_optional_inputs
=
name_prefix
+
'_optional_inputs'
name_for_funcs
=
name_prefix
+
'_funcs'
tf_variables
[
name_prefix
]
=
dict
()
tf_variables
[
name_prefix
][
'optional_inputs'
]
=
tf
.
get_variable
(
name_for_optional_inputs
,
[
len
(
optional_inputs
)],
dtype
=
tf
.
bool
,
trainable
=
False
)
tf_variables
[
name_prefix
][
'funcs'
]
=
tf
.
get_variable
(
name_for_funcs
,
[],
dtype
=
tf
.
int64
,
trainable
=
False
)
# get real values using their variable names
real_optional_inputs_value
=
[
optional_inputs
[
name
]
for
name
in
name_space
[
name_prefix
][
'optional_inputs'
]]
real_func_value
=
[
funcs
[
name
]
for
name
in
name_space
[
name_prefix
][
'funcs'
]]
real_funcs_args
=
[
funcs_args
[
name
]
for
name
in
name_space
[
name_prefix
][
'funcs'
]]
# build tensorflow graph of geting chosen inputs by masking
real_chosen_inputs
=
tf
.
boolean_mask
(
real_optional_inputs_value
,
tf_variables
[
name_prefix
][
'optional_inputs'
])
# build tensorflow graph of different branches by using tf.case
branches
=
dict
()
for
func_id
in
range
(
len
(
funcs
)):
func_output
=
real_func_value
[
func_id
](
[
fixed_inputs
,
real_chosen_inputs
],
**
real_funcs_args
[
func_id
])
branches
[
tf
.
equal
(
tf_variables
[
name_prefix
][
'funcs'
],
func_id
)]
=
lambda
:
func_output
layer_out
=
tf
.
case
(
branches
,
exclusive
=
True
,
default
=
lambda
:
func_output
)
return
layer_out
def
oneshot_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
,
tf
):
'''Similar to enas mode, oneshot mode also builds the full model graph.
The difference is that oneshot mode does not receive subgraph.
Instead, it uses dropout to randomly dropout inputs and ops.'''
# NNI requires to get_next_parameter before report a result. But the parameter will not be used in this mode
if
trial
.
_params
is
None
:
trial
.
get_next_parameter
()
optional_inputs
=
list
(
optional_inputs
.
values
())
inputs_num
=
len
(
optional_inputs
)
# Calculate dropout rate according to the formular r^(1/k), where r is a hyper-parameter and k is the number of inputs
if
inputs_num
>
0
:
rate
=
0.01
**
(
1
/
inputs_num
)
noise_shape
=
[
inputs_num
]
+
[
1
]
*
len
(
optional_inputs
[
0
].
get_shape
())
optional_inputs
=
tf
.
nn
.
dropout
(
optional_inputs
,
rate
=
rate
,
noise_shape
=
noise_shape
)
optional_inputs
=
[
optional_inputs
[
idx
]
for
idx
in
range
(
inputs_num
)]
layer_outs
=
[
func
([
fixed_inputs
,
optional_inputs
],
**
funcs_args
[
func_name
])
for
func_name
,
func
in
funcs
.
items
()]
layer_out
=
tf
.
add_n
(
layer_outs
)
return
layer_out
def
reload_tensorflow_variables
(
session
,
tf
=
None
):
'''In Enas mode, this function reload every signal varaible created in `enas_mode` function so
the whole tensorflow graph will be changed into certain subgraph recerived from Tuner.
---------------
session: the tensorflow session created by users
tf: tensorflow module
'''
subgraph_from_tuner
=
trial
.
get_next_parameter
()
for
mutable_id
,
mutable_block
in
subgraph_from_tuner
.
items
():
if
mutable_id
not
in
name_space
:
continue
for
mutable_layer_id
,
mutable_layer
in
mutable_block
.
items
():
name_prefix
=
"{}_{}"
.
format
(
mutable_id
,
mutable_layer_id
)
# extract layer information from the subgraph sampled by tuner
chosen_layer
=
name_space
[
name_prefix
][
'funcs'
].
index
(
mutable_layer
[
"chosen_layer"
])
chosen_inputs
=
[
1
if
inp
in
mutable_layer
[
"chosen_inputs"
]
else
0
for
inp
in
name_space
[
name_prefix
][
'optional_inputs'
]]
# load these information into pre-defined tensorflow variables
tf_variables
[
name_prefix
][
'funcs'
].
load
(
chosen_layer
,
session
)
tf_variables
[
name_prefix
][
'optional_inputs'
].
load
(
chosen_inputs
,
session
)
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
View file @
d6febf29
...
...
@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
"""
self
.
search_space
=
search_space
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""
Returns a set of trial neural architecture, as a serializable object.
...
...
@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
return
json_out
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
""" Record an observation of the objective function.
Parameters
...
...
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
View file @
d6febf29
...
...
@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
else
:
self
.
logger
.
warning
(
'update search space is not supported.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""receive_trial_result
Parameters
...
...
@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
converted_dict
[
key
]
=
value
return
converted_dict
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""generate one instance of hyperparameters
Parameters
...
...
@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
self
.
total_data
[
parameter_id
]
=
challenger
return
self
.
convert_loguniform_categorical
(
challenger
.
get_dictionary
())
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""generate mutiple instances of hyperparameters
Parameters
...
...
src/sdk/pynni/nni/smartparam.py
View file @
d6febf29
...
...
@@ -23,6 +23,7 @@ import random
from
.env_vars
import
trial_env_vars
from
.
import
trial
from
.nas_utils
import
classic_mode
,
enas_mode
,
oneshot_mode
__all__
=
[
...
...
@@ -124,7 +125,9 @@ else:
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
):
optional_input_size
,
mode
=
'classic_mode'
,
tf
=
None
):
'''execute the chosen function and inputs.
Below is an example of chosen function and inputs:
{
...
...
@@ -144,14 +147,38 @@ else:
fixed_inputs:
optional_inputs: dict of optional inputs
optional_input_size: number of candidate inputs to be chosen
tf: tensorflow module
'''
mutable_block
=
_get_param
(
mutable_id
)
chosen_layer
=
mutable_block
[
mutable_layer_id
][
"chosen_layer"
]
chosen_inputs
=
mutable_block
[
mutable_layer_id
][
"chosen_inputs"
]
real_chosen_inputs
=
[
optional_inputs
[
input_name
]
for
input_name
in
chosen_inputs
]
layer_out
=
funcs
[
chosen_layer
]([
fixed_inputs
,
real_chosen_inputs
],
**
funcs_args
[
chosen_layer
])
return
layer_out
if
mode
==
'classic_mode'
:
return
classic_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
)
elif
mode
==
'enas_mode'
:
assert
tf
is
not
None
,
'Internal Error: Tensorflow should not be None in enas_mode'
return
enas_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
,
tf
)
elif
mode
==
'oneshot_mode'
:
assert
tf
is
not
None
,
'Internal Error: Tensorflow should not be None in oneshot_mode'
return
oneshot_mode
(
mutable_id
,
mutable_layer_id
,
funcs
,
funcs_args
,
fixed_inputs
,
optional_inputs
,
optional_input_size
,
tf
)
else
:
raise
RuntimeError
(
'Unrecognized mode: %s'
%
mode
)
def
_get_param
(
key
):
if
trial
.
_params
is
None
:
...
...
src/sdk/pynni/nni/tuner.py
View file @
d6febf29
...
...
@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
class
Tuner
(
Recoverable
):
# pylint: disable=no-self-use,unused-argument
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
):
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Call 'generate_parameters()' by 'count' times by default.
User code must override either this function or 'generate_parameters()'.
...
...
@@ -49,13 +49,13 @@ class Tuner(Recoverable):
for
parameter_id
in
parameter_id_list
:
try
:
_logger
.
debug
(
"generating param for {}"
.
format
(
parameter_id
))
res
=
self
.
generate_parameters
(
parameter_id
)
res
=
self
.
generate_parameters
(
parameter_id
,
**
kwargs
)
except
nni
.
NoMoreTrialError
:
return
result
result
.
append
(
res
)
return
result
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameters: object created by 'generate_parameters()'
...
...
@@ -63,7 +63,7 @@ class Tuner(Recoverable):
"""
raise
NotImplementedError
(
'Tuner: receive_trial_result not implemented'
)
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameters: object created by user
...
...
@@ -71,7 +71,7 @@ class Tuner(Recoverable):
"""
_logger
.
info
(
'Customized trial job %s ignored by tuner'
,
parameter_id
)
def
trial_end
(
self
,
parameter_id
,
success
):
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
...
...
src/sdk/pynni/tests/test_multi_phase_tuner.py
deleted
100644 → 0
View file @
77c95479
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
logging
import
random
from
io
import
BytesIO
import
nni
import
nni.protocol
from
nni.protocol
import
CommandType
,
send
,
receive
from
nni.multi_phase.multi_phase_tuner
import
MultiPhaseTuner
from
nni.multi_phase.multi_phase_dispatcher
import
MultiPhaseMsgDispatcher
from
unittest
import
TestCase
,
main
class
NaiveMultiPhaseTuner
(
MultiPhaseTuner
):
'''
supports only choices
'''
def
__init__
(
self
):
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
,
trial_job_id
=
None
):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters
=
{}
if
self
.
search_space
is
None
:
raise
AssertionError
(
'Search space not specified'
)
for
k
in
self
.
search_space
:
param
=
self
.
search_space
[
k
]
if
not
param
[
'_type'
]
==
'choice'
:
raise
ValueError
(
'Only choice type is supported'
)
param_values
=
param
[
'_value'
]
generated_parameters
[
k
]
=
param_values
[
random
.
randint
(
0
,
len
(
param_values
)
-
1
)]
logging
.
getLogger
(
__name__
).
debug
(
generated_parameters
)
return
generated_parameters
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
logging
.
getLogger
(
__name__
).
debug
(
'receive_trial_result: {},{},{},{}'
.
format
(
parameter_id
,
parameters
,
value
,
trial_job_id
))
def
receive_customized_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
trial_job_id
):
pass
def
update_search_space
(
self
,
search_space
):
self
.
search_space
=
search_space
_in_buf
=
BytesIO
()
_out_buf
=
BytesIO
()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_out_file
=
_in_buf
nni
.
protocol
.
_in_file
=
_out_buf
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
nni
.
protocol
.
_in_file
=
_in_buf
nni
.
protocol
.
_out_file
=
_out_buf
def
_test_tuner
():
_reverse_io
()
# now we are sending to Tuner's incoming stream
send
(
CommandType
.
UpdateSearchSpace
,
"{
\"
learning_rate
\"
: {
\"
_value
\"
: [0.0001, 0.001, 0.002, 0.005, 0.01],
\"
_type
\"
:
\"
choice
\"
},
\"
optimizer
\"
: {
\"
_value
\"
: [
\"
Adam
\"
,
\"
SGD
\"
],
\"
_type
\"
:
\"
choice
\"
}}"
)
send
(
CommandType
.
RequestTrialJobs
,
'2'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}'
)
send
(
CommandType
.
AddCustomizedTrialJob
,
'{"param":-1}'
)
send
(
CommandType
.
ReportMetricData
,
'{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}'
)
send
(
CommandType
.
RequestTrialJobs
,
'1'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"abc"}'
)
_restore_io
()
tuner
=
NaiveMultiPhaseTuner
()
dispatcher
=
MultiPhaseMsgDispatcher
(
tuner
)
dispatcher
.
run
()
_reverse_io
()
# now we are receiving from Tuner's outgoing stream
command
,
data
=
receive
()
# this one is customized
print
(
command
,
data
)
class
MultiPhaseTestCase
(
TestCase
):
def
test_tuner
(
self
):
_test_tuner
()
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
src/sdk/pynni/tests/test_smartparam.py
View file @
d6febf29
...
...
@@ -38,7 +38,13 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice3/choice'
:
'[1, 2]'
,
'test_smartparam/choice4/choice'
:
'{"a", 2}'
,
'test_smartparam/func/function_choice'
:
'bar'
,
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
'test_smartparam/lambda_func/function_choice'
:
"lambda: 2*3"
,
'mutable_block_66'
:{
'mutable_layer_0'
:{
'chosen_layer'
:
'conv2D(size=5)'
,
'chosen_inputs'
:
[
'y'
]
}
}
}
nni
.
trial
.
_params
=
{
'parameter_id'
:
'test_trial'
,
'parameters'
:
params
}
...
...
@@ -61,6 +67,13 @@ class SmartParamTestCase(TestCase):
val
=
nni
.
function_choice
({
"lambda: 2*3"
:
lambda
:
2
*
3
,
"lambda: 3*4"
:
lambda
:
3
*
4
},
name
=
'lambda_func'
,
key
=
'test_smartparam/lambda_func/function_choice'
)
self
.
assertEqual
(
val
,
6
)
def
test_mutable_layer
(
self
):
layer_out
=
nni
.
mutable_layer
(
'mutable_block_66'
,
'mutable_layer_0'
,
{
'conv2D(size=3)'
:
conv2D
,
'conv2D(size=5)'
:
conv2D
},
{
'conv2D(size=3)'
:
{
'size'
:
3
},
'conv2D(size=5)'
:
{
'size'
:
5
}},
[
100
],
{
'x'
:
1
,
'y'
:
2
},
1
,
'classic_mode'
)
self
.
assertEqual
(
layer_out
,
[
100
,
2
,
5
])
def
foo
():
return
'foo'
...
...
@@ -68,6 +81,8 @@ def foo():
def
bar
():
return
'bar'
def
conv2D
(
inputs
,
size
=
3
):
return
inputs
[
0
]
+
inputs
[
1
]
+
[
size
]
if
__name__
==
'__main__'
:
main
()
src/sdk/pynni/tests/test_tuner.py
View file @
d6febf29
...
...
@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
self
.
trial_results
=
[
]
self
.
search_space
=
None
def
generate_parameters
(
self
,
parameter_id
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
self
.
param
+=
2
...
...
@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
'search_space'
:
self
.
search_space
}
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
reward
=
extract_scalar_reward
(
value
)
self
.
trial_results
.
append
((
parameter_id
,
parameters
[
'param'
],
reward
,
False
))
...
...
@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
command
,
data
=
receive
()
# this one is customized
data
=
json
.
loads
(
data
)
self
.
assertIs
(
command
,
CommandType
.
NewTrialJob
)
self
.
assertEqual
(
data
,
{
'parameter_id'
:
2
,
'parameter_source'
:
'customized'
,
'parameters'
:
{
'param'
:
-
1
}
})
self
.
assertEqual
(
data
[
'parameter_id'
],
2
)
self
.
assertEqual
(
data
[
'parameter_source'
],
'customized'
)
self
.
assertEqual
(
data
[
'parameters'
],
{
'param'
:
-
1
})
self
.
_assert_params
(
3
,
6
,
[[
1
,
4
,
11
,
False
],
[
2
,
-
1
,
22
,
True
]],
{
'name'
:
'SS0'
})
...
...
src/webui/src/components/Modal/Compare.tsx
0 → 100644
View file @
d6febf29
import
*
as
React
from
'
react
'
;
import
{
Row
,
Modal
}
from
'
antd
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
import
'
../../static/style/compare.scss
'
;
import
{
TableObj
,
Intermedia
,
TooltipForIntermediate
}
from
'
src/static/interface
'
;
// the modal of trial compare
interface
CompareProps
{
compareRows
:
Array
<
TableObj
>
;
visible
:
boolean
;
cancelFunc
:
()
=>
void
;
}
class
Compare
extends
React
.
Component
<
CompareProps
,
{}
>
{
public
_isCompareMount
:
boolean
;
constructor
(
props
:
CompareProps
)
{
super
(
props
);
}
intermediate
=
()
=>
{
const
{
compareRows
}
=
this
.
props
;
const
trialIntermediate
:
Array
<
Intermedia
>
=
[];
const
idsList
:
Array
<
string
>
=
[];
Object
.
keys
(
compareRows
).
map
(
item
=>
{
const
temp
=
compareRows
[
item
];
trialIntermediate
.
push
({
name
:
temp
.
id
,
data
:
temp
.
description
.
intermediate
,
type
:
'
line
'
,
hyperPara
:
temp
.
description
.
parameters
});
idsList
.
push
(
temp
.
id
);
});
// find max intermediate number
trialIntermediate
.
sort
((
a
,
b
)
=>
{
return
(
b
.
data
.
length
-
a
.
data
.
length
);
});
const
legend
:
Array
<
string
>
=
[];
// max length
const
length
=
trialIntermediate
[
0
]
!==
undefined
?
trialIntermediate
[
0
].
data
.
length
:
0
;
const
xAxis
:
Array
<
number
>
=
[];
Object
.
keys
(
trialIntermediate
).
map
(
item
=>
{
const
temp
=
trialIntermediate
[
item
];
legend
.
push
(
temp
.
name
);
});
for
(
let
i
=
1
;
i
<=
length
;
i
++
)
{
xAxis
.
push
(
i
);
}
const
option
=
{
tooltip
:
{
trigger
:
'
item
'
,
enterable
:
true
,
position
:
function
(
point
:
Array
<
number
>
,
data
:
TooltipForIntermediate
)
{
if
(
data
.
dataIndex
<
length
/
2
)
{
return
[
point
[
0
],
80
];
}
else
{
return
[
point
[
0
]
-
300
,
80
];
}
},
formatter
:
function
(
data
:
TooltipForIntermediate
)
{
const
trialId
=
data
.
seriesName
;
let
obj
=
{};
const
temp
=
trialIntermediate
.
find
(
key
=>
key
.
name
===
trialId
);
if
(
temp
!==
undefined
)
{
obj
=
temp
.
hyperPara
;
}
return
'
<div class="tooldetailAccuracy">
'
+
'
<div>Trial ID:
'
+
trialId
+
'
</div>
'
+
'
<div>Intermediate:
'
+
data
.
data
+
'
</div>
'
+
'
<div>Parameters:
'
+
'
<pre>
'
+
JSON
.
stringify
(
obj
,
null
,
4
)
+
'
</pre>
'
+
'
</div>
'
+
'
</div>
'
;
}
},
grid
:
{
left
:
'
5%
'
,
top
:
40
,
containLabel
:
true
},
legend
:
{
data
:
idsList
},
xAxis
:
{
type
:
'
category
'
,
name
:
'
Step
'
,
boundaryGap
:
false
,
data
:
xAxis
},
yAxis
:
{
type
:
'
value
'
,
name
:
'
metric
'
},
series
:
trialIntermediate
};
return
(
<
ReactEcharts
option
=
{
option
}
style
=
{
{
width
:
'
100%
'
,
height
:
418
,
margin
:
'
0 auto
'
}
}
notMerge
=
{
true
}
// update now
/>
);
}
// render table column ---
initColumn
=
()
=>
{
const
{
compareRows
}
=
this
.
props
;
const
idList
:
Array
<
string
>
=
[];
const
durationList
:
Array
<
number
>
=
[];
const
parameterList
:
Array
<
object
>
=
[];
let
parameterKeys
:
Array
<
string
>
=
[];
if
(
compareRows
.
length
!==
0
)
{
parameterKeys
=
Object
.
keys
(
compareRows
[
0
].
description
.
parameters
);
}
Object
.
keys
(
compareRows
).
map
(
item
=>
{
const
temp
=
compareRows
[
item
];
idList
.
push
(
temp
.
id
);
durationList
.
push
(
temp
.
duration
);
parameterList
.
push
(
temp
.
description
.
parameters
);
});
return
(
<
table
className
=
"compare"
>
<
tbody
>
<
tr
>
<
td
/>
{
Object
.
keys
(
idList
).
map
(
key
=>
{
return
(
<
td
className
=
"value"
key
=
{
key
}
>
{
idList
[
key
]
}
</
td
>
);
})
}
</
tr
>
<
tr
>
<
td
className
=
"column"
>
Default metric
</
td
>
{
Object
.
keys
(
compareRows
).
map
(
index
=>
{
const
temp
=
compareRows
[
index
];
return
(
<
td
className
=
"value"
key
=
{
index
}
>
<
IntermediateVal
record
=
{
temp
}
/>
</
td
>
);
})
}
</
tr
>
<
tr
>
<
td
className
=
"column"
>
duration
</
td
>
{
Object
.
keys
(
durationList
).
map
(
index
=>
{
return
(
<
td
className
=
"value"
key
=
{
index
}
>
{
durationList
[
index
]
}
</
td
>
);
})
}
</
tr
>
{
Object
.
keys
(
parameterKeys
).
map
(
index
=>
{
return
(
<
tr
key
=
{
index
}
>
<
td
className
=
"column"
key
=
{
index
}
>
{
parameterKeys
[
index
]
}
</
td
>
{
Object
.
keys
(
parameterList
).
map
(
key
=>
{
return
(
<
td
key
=
{
key
}
className
=
"value"
>
{
parameterList
[
key
][
parameterKeys
[
index
]]
}
</
td
>
);
})
}
</
tr
>
);
})
}
</
tbody
>
</
table
>
);
}
componentDidMount
()
{
this
.
_isCompareMount
=
true
;
}
componentWillUnmount
()
{
this
.
_isCompareMount
=
false
;
}
render
()
{
const
{
visible
,
cancelFunc
}
=
this
.
props
;
return
(
<
Modal
title
=
"Compare trials"
visible
=
{
visible
}
onCancel
=
{
cancelFunc
}
footer
=
{
null
}
destroyOnClose
=
{
true
}
maskClosable
=
{
false
}
width
=
"90%"
>
<
Row
>
{
this
.
intermediate
()
}
</
Row
>
<
Row
>
{
this
.
initColumn
()
}
</
Row
>
</
Modal
>
);
}
}
export
default
Compare
;
src/webui/src/components/Overview.tsx
View file @
d6febf29
...
...
@@ -353,8 +353,10 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
const
indexarr
:
Array
<
number
>
=
[];
Object
.
keys
(
sourcePoint
).
map
(
item
=>
{
const
items
=
sourcePoint
[
item
];
if
(
items
.
acc
!==
undefined
)
{
accarr
.
push
(
items
.
acc
.
default
);
indexarr
.
push
(
items
.
sequenceId
);
}
});
const
accOption
=
{
// support max show 0.0000000
...
...
src/webui/src/components/SlideBar.tsx
View file @
d6febf29
...
...
@@ -29,7 +29,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
public
_isMounted
=
false
;
public
divMenu
:
HTMLDivElement
|
null
;
public
countOfMenu
:
number
=
0
;
public
selectHTML
:
Select
|
null
;
constructor
(
props
:
SliderProps
)
{
...
...
@@ -208,7 +207,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
menu
=
()
=>
{
this
.
countOfMenu
=
0
;
return
(
<
Menu
onClick
=
{
this
.
handleMenuClick
}
>
<
Menu
.
Item
key
=
"1"
>
Experiment Parameters
</
Menu
.
Item
>
...
...
@@ -223,7 +221,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
const
{
version
}
=
this
.
state
;
const
feedBackLink
=
`https://github.com/Microsoft/nni/issues/new?labels=
${
version
}
`
;
return
(
<
Menu
onClick
=
{
this
.
handleMenuClick
}
mode
=
"inline
"
>
<
Menu
onClick
=
{
this
.
handleMenuClick
}
className
=
"menuModal
"
>
<
Menu
.
Item
key
=
"overview"
><
Link
to
=
{
'
/oview
'
}
>
Overview
</
Link
></
Menu
.
Item
>
<
Menu
.
Item
key
=
"detail"
><
Link
to
=
{
'
/detail
'
}
>
Trials detail
</
Link
></
Menu
.
Item
>
<
Menu
.
Item
key
=
"fresh"
>
...
...
@@ -250,18 +248,6 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
// nav bar <1299
showMenu
=
()
=>
{
if
(
this
.
divMenu
!==
null
)
{
this
.
countOfMenu
=
this
.
countOfMenu
+
1
;
if
(
this
.
countOfMenu
%
2
===
0
)
{
this
.
divMenu
.
setAttribute
(
'
class
'
,
'
hide
'
);
}
else
{
this
.
divMenu
.
setAttribute
(
'
class
'
,
'
show
'
);
}
}
}
select
=
()
=>
{
return
(
<
Select
...
...
@@ -322,7 +308,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
</
li
>
<
li
className
=
"feedback"
>
<
span
className
=
"fresh"
onClick
=
{
this
.
fresh
}
>
<
Icon
type
=
"sync"
/><
span
>
Fresh
</
span
>
<
Icon
type
=
"sync"
/><
span
>
Fresh
</
span
>
</
span
>
<
Dropdown
className
=
"dropdown"
...
...
@@ -350,8 +336,9 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
<
MediaQuery
query
=
"(max-width: 1299px)"
>
<
Row
className
=
"little"
>
<
Col
span
=
{
6
}
className
=
"menu"
>
<
Icon
type
=
"unordered-list"
className
=
"more"
onClick
=
{
this
.
showMenu
}
/>
<
div
ref
=
{
div
=>
this
.
divMenu
=
div
}
className
=
"hide"
>
{
this
.
navigationBar
()
}
</
div
>
<
Dropdown
overlay
=
{
this
.
navigationBar
()
}
trigger
=
{
[
'
click
'
]
}
>
<
Icon
type
=
"unordered-list"
className
=
"more"
/>
</
Dropdown
>
</
Col
>
<
Col
span
=
{
10
}
className
=
"logo"
>
<
Link
to
=
{
'
/oview
'
}
>
...
...
src/webui/src/components/TrialsDetail.tsx
View file @
d6febf29
...
...
@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
{
/* trial table list */
}
<
Title1
text
=
"Trial jobs"
icon
=
"6.png"
/>
<
Row
className
=
"allList"
>
<
Col
span
=
{
1
2
}
>
<
Col
span
=
{
1
0
}
>
<
span
>
Show
</
span
>
<
Select
className
=
"entry"
...
...
@@ -392,9 +392,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
</
Select
>
<
span
>
entries
</
span
>
</
Col
>
<
Col
span
=
{
12
}
className
=
"right"
>
<
Row
>
<
Col
span
=
{
12
}
>
<
Col
span
=
{
14
}
className
=
"right"
>
<
Button
type
=
"primary"
className
=
"tableButton editStyle"
...
...
@@ -402,8 +400,14 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
>
Add column
</
Button
>
</
Col
>
<
Col
span
=
{
12
}
>
<
Button
type
=
"primary"
className
=
"tableButton editStyle mediateBtn"
// use child-component tableList's function, the function is in child-component.
onClick
=
{
this
.
tableList
?
this
.
tableList
.
compareBtn
:
this
.
test
}
>
Compare
</
Button
>
<
Input
type
=
"text"
placeholder
=
"Search by id, trial No. or status"
...
...
@@ -412,8 +416,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
/>
</
Col
>
</
Row
>
</
Col
>
</
Row
>
<
TableList
entries
=
{
entriesTable
}
tableSource
=
{
source
}
...
...
src/webui/src/components/trial-detail/Intermeidate.tsx
View file @
d6febf29
import
*
as
React
from
'
react
'
;
import
{
Row
,
Col
,
Button
,
Switch
}
from
'
antd
'
;
import
{
TooltipForIntermediate
,
TableObj
}
from
'
../../static/interface
'
;
import
{
TooltipForIntermediate
,
TableObj
,
Intermedia
}
from
'
../../static/interface
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
require
(
'
echarts/lib/component/tooltip
'
);
require
(
'
echarts/lib/component/title
'
);
interface
Intermedia
{
name
:
string
;
// id
type
:
string
;
data
:
Array
<
number
|
object
>
;
// intermediate data
hyperPara
:
object
;
// each trial hyperpara value
}
interface
IntermediateState
{
detailSource
:
Array
<
TableObj
>
;
interSource
:
object
;
...
...
src/webui/src/components/trial-detail/Para.tsx
View file @
d6febf29
...
...
@@ -145,7 +145,8 @@ class Para extends React.Component<ParaProps, ParaState> {
const
parallelAxis
:
Array
<
Dimobj
>
=
[];
// search space range and specific value [only number]
for
(
let
i
=
0
;
i
<
dimName
.
length
;
i
++
)
{
let
i
=
0
;
for
(
i
;
i
<
dimName
.
length
;
i
++
)
{
const
searchKey
=
searchRange
[
dimName
[
i
]];
switch
(
searchKey
.
_type
)
{
case
'
uniform
'
:
...
...
@@ -213,6 +214,13 @@ class Para extends React.Component<ParaProps, ParaState> {
}
}
parallelAxis
.
push
({
dim
:
i
,
name
:
'
default metric
'
,
nameTextStyle
:
{
fontWeight
:
700
}
});
if
(
lenOfDataSource
===
0
)
{
const
optionOfNull
=
{
parallelAxis
,
...
...
@@ -229,8 +237,8 @@ class Para extends React.Component<ParaProps, ParaState> {
const
length
=
value
.
length
;
if
(
length
>
16
)
{
const
temp
=
value
.
split
(
''
);
for
(
let
i
=
16
;
i
<
temp
.
length
;
i
+=
17
)
{
temp
[
i
]
+=
'
\n
'
;
for
(
let
m
=
16
;
m
<
temp
.
length
;
m
+=
17
)
{
temp
[
m
]
+=
'
\n
'
;
}
return
temp
.
join
(
''
);
}
else
{
...
...
src/webui/src/components/trial-detail/TableList.tsx
View file @
d6febf29
import
*
as
React
from
'
react
'
;
import
axios
from
'
axios
'
;
import
ReactEcharts
from
'
echarts-for-react
'
;
import
{
Row
,
Table
,
Button
,
Popconfirm
,
Modal
,
Checkbox
}
from
'
antd
'
;
import
{
Row
,
Table
,
Button
,
Popconfirm
,
Modal
,
Checkbox
,
Select
}
from
'
antd
'
;
const
Option
=
Select
.
Option
;
const
CheckboxGroup
=
Checkbox
.
Group
;
import
{
MANAGER_IP
,
trialJobStatus
,
COLUMN
,
COLUMN_INDEX
}
from
'
../../static/const
'
;
import
{
convertDuration
,
intermediateGraphOption
,
killJob
}
from
'
../../static/function
'
;
import
{
TableObj
,
TrialJob
}
from
'
../../static/interface
'
;
import
OpenRow
from
'
../public-child/OpenRow
'
;
import
Compare
from
'
../Modal/Compare
'
;
import
IntermediateVal
from
'
../public-child/IntermediateVal
'
;
// table default metric column
import
'
../../static/style/search.scss
'
;
require
(
'
../../static/style/tableStatus.css
'
);
...
...
@@ -38,6 +40,12 @@ interface TableListState {
isObjFinal
:
boolean
;
isShowColumn
:
boolean
;
columnSelected
:
Array
<
string
>
;
// user select columnKeys
selectRows
:
Array
<
TableObj
>
;
isShowCompareModal
:
boolean
;
selectedRowKeys
:
string
[]
|
number
[];
intermediateData
:
Array
<
object
>
;
// a trial's intermediate results (include dict)
intermediateId
:
string
;
intermediateOtherKeys
:
Array
<
string
>
;
}
interface
ColumnIndex
{
...
...
@@ -50,6 +58,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
public
_isMounted
=
false
;
public
intervalTrialLog
=
10
;
public
_trialId
:
string
;
public
tables
:
Table
<
TableObj
>
|
null
;
constructor
(
props
:
TableListProps
)
{
super
(
props
);
...
...
@@ -59,7 +68,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
modalVisible
:
false
,
isObjFinal
:
false
,
isShowColumn
:
false
,
columnSelected
:
COLUMN
isShowCompareModal
:
false
,
columnSelected
:
COLUMN
,
selectRows
:
[],
selectedRowKeys
:
[],
// close selected trial message after modal closed
intermediateData
:
[],
intermediateId
:
''
,
intermediateOtherKeys
:
[]
};
}
...
...
@@ -71,7 +86,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
.
then
(
res
=>
{
if
(
res
.
status
===
200
)
{
const
intermediateArr
:
number
[]
=
[];
// support intermediate result is dict
// support intermediate result is dict because the last intermediate result is
// final result in a succeed trial, it may be a dict.
// get intermediate result dict keys array
let
otherkeys
:
Array
<
string
>
=
[
'
default
'
];
if
(
res
.
data
.
length
!==
0
)
{
otherkeys
=
Object
.
keys
(
JSON
.
parse
(
res
.
data
[
0
].
data
));
}
// intermediateArr just store default val
Object
.
keys
(
res
.
data
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
res
.
data
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
...
...
@@ -83,7 +105,10 @@ class TableList extends React.Component<TableListProps, TableListState> {
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
id
);
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
intermediateOption
:
intermediate
intermediateData
:
res
.
data
,
// store origin intermediate data for a trial
intermediateOption
:
intermediate
,
intermediateOtherKeys
:
otherkeys
,
intermediateId
:
id
}));
}
}
...
...
@@ -95,6 +120,38 @@ class TableList extends React.Component<TableListProps, TableListState> {
}
}
selectOtherKeys
=
(
value
:
string
)
=>
{
const
isShowDefault
:
boolean
=
value
===
'
default
'
?
true
:
false
;
const
{
intermediateData
,
intermediateId
}
=
this
.
state
;
const
intermediateArr
:
number
[]
=
[];
// just watch default key-val
if
(
isShowDefault
===
true
)
{
Object
.
keys
(
intermediateData
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
intermediateData
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
intermediateArr
.
push
(
temp
[
value
]);
}
else
{
intermediateArr
.
push
(
temp
);
}
});
}
else
{
Object
.
keys
(
intermediateData
).
map
(
item
=>
{
const
temp
=
JSON
.
parse
(
intermediateData
[
item
].
data
);
if
(
typeof
temp
===
'
object
'
)
{
intermediateArr
.
push
(
temp
[
value
]);
}
});
}
const
intermediate
=
intermediateGraphOption
(
intermediateArr
,
intermediateId
);
// re-render
if
(
this
.
_isMounted
)
{
this
.
setState
(()
=>
({
intermediateOption
:
intermediate
}));
}
}
hideIntermediateModal
=
()
=>
{
if
(
this
.
_isMounted
)
{
this
.
setState
({
...
...
@@ -184,6 +241,31 @@ class TableList extends React.Component<TableListProps, TableListState> {
);
}
fillSelectedRowsTostate
=
(
selected
:
number
[]
|
string
[],
selectedRows
:
Array
<
TableObj
>
)
=>
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
(()
=>
({
selectRows
:
selectedRows
,
selectedRowKeys
:
selected
}));
}
}
// open Compare-modal
compareBtn
=
()
=>
{
const
{
selectRows
}
=
this
.
state
;
if
(
selectRows
.
length
===
0
)
{
alert
(
'
Please select datas you want to compare!
'
);
}
else
{
if
(
this
.
_isMounted
===
true
)
{
this
.
setState
({
isShowCompareModal
:
true
});
}
}
}
// close Compare-modal
hideCompareModal
=
()
=>
{
// close modal. clear select rows data, clear selected track
if
(
this
.
_isMounted
)
{
this
.
setState
({
isShowCompareModal
:
false
,
selectedRowKeys
:
[],
selectRows
:
[]
});
}
}
componentDidMount
()
{
this
.
_isMounted
=
true
;
}
...
...
@@ -195,7 +277,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
render
()
{
const
{
entries
,
tableSource
,
updateList
}
=
this
.
props
;
const
{
intermediateOption
,
modalVisible
,
isShowColumn
,
columnSelected
}
=
this
.
state
;
const
{
intermediateOption
,
modalVisible
,
isShowColumn
,
columnSelected
,
selectRows
,
isShowCompareModal
,
selectedRowKeys
,
intermediateOtherKeys
}
=
this
.
state
;
const
rowSelection
=
{
selectedRowKeys
:
selectedRowKeys
,
onChange
:
(
selected
:
string
[]
|
number
[],
selectedRows
:
Array
<
TableObj
>
)
=>
{
this
.
fillSelectedRowsTostate
(
selected
,
selectedRows
);
}
};
let
showTitle
=
COLUMN
;
let
bgColor
=
''
;
const
trialJob
:
Array
<
TrialJob
>
=
[];
...
...
@@ -417,7 +506,9 @@ class TableList extends React.Component<TableListProps, TableListState> {
<
Row
className
=
"tableList"
>
<
div
id
=
"tableList"
>
<
Table
ref
=
{
(
table
:
Table
<
TableObj
>
|
null
)
=>
this
.
tables
=
table
}
columns
=
{
showColumn
}
rowSelection
=
{
rowSelection
}
expandedRowRender
=
{
this
.
openRow
}
dataSource
=
{
tableSource
}
className
=
"commonTableStyle"
...
...
@@ -432,6 +523,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
destroyOnClose
=
{
true
}
width
=
"80%"
>
{
intermediateOtherKeys
.
length
>
1
?
<
Row
className
=
"selectKeys"
>
<
Select
className
=
"select"
defaultValue
=
"default"
onSelect
=
{
this
.
selectOtherKeys
}
>
{
Object
.
keys
(
intermediateOtherKeys
).
map
(
item
=>
{
const
keys
=
intermediateOtherKeys
[
item
];
return
<
Option
value
=
{
keys
}
key
=
{
item
}
>
{
keys
}
</
Option
>;
})
}
</
Select
>
</
Row
>
:
<
div
/>
}
<
ReactEcharts
option
=
{
intermediateOption
}
style
=
{
{
...
...
@@ -458,6 +570,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
className
=
"titleColumn"
/>
</
Modal
>
<
Compare
compareRows
=
{
selectRows
}
visible
=
{
isShowCompareModal
}
cancelFunc
=
{
this
.
hideCompareModal
}
/>
</
Row
>
);
}
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment