Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0663218b
Unverified
Commit
0663218b
authored
Apr 22, 2019
by
SparkSnail
Committed by
GitHub
Apr 22, 2019
Browse files
Merge pull request #163 from Microsoft/master
merge master
parents
6c9360a5
cf983800
Changes
116
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
249 additions
and
68 deletions
+249
-68
src/nni_manager/training_service/test/localTrainingService.test.ts
...anager/training_service/test/localTrainingService.test.ts
+1
-1
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
+4
-1
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+32
-0
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+5
-29
src/sdk/pynni/nni/env_vars.py
src/sdk/pynni/nni/env_vars.py
+48
-0
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
+3
-1
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
+29
-3
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+3
-1
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+25
-0
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+24
-3
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+17
-10
src/sdk/pynni/nni/msg_dispatcher_base.py
src/sdk/pynni/nni/msg_dispatcher_base.py
+10
-4
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
+10
-0
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
+9
-0
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
.../pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
+4
-0
src/sdk/pynni/nni/platform/__init__.py
src/sdk/pynni/nni/platform/__init__.py
+5
-5
src/sdk/pynni/nni/platform/local.py
src/sdk/pynni/nni/platform/local.py
+15
-9
src/sdk/pynni/nni/protocol.py
src/sdk/pynni/nni/protocol.py
+1
-0
src/sdk/pynni/nni/recoverable.py
src/sdk/pynni/nni/recoverable.py
+1
-1
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
+3
-0
No files found.
src/nni_manager/training_service/test/localTrainingService.test.ts
View file @
0663218b
...
...
@@ -31,7 +31,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import
{
LocalTrainingService
}
from
'
../local/localTrainingService
'
;
// TODO: copy mockedTrail.py to local folder
const
localCodeDir
:
string
=
tmp
.
dirSync
().
name
const
localCodeDir
:
string
=
tmp
.
dirSync
().
name
.
split
(
'
\\
'
).
join
(
'
\\\\
'
);
const
mockedTrialPath
:
string
=
'
./training_service/test/mockedTrial.py
'
fs
.
copyFileSync
(
mockedTrialPath
,
localCodeDir
+
'
/mockedTrial.py
'
)
...
...
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
View file @
0663218b
...
...
@@ -94,4 +94,7 @@ class BatchTuner(Tuner):
return
self
.
values
[
self
.
count
]
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
pass
\ No newline at end of file
pass
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
0663218b
...
...
@@ -573,3 +573,35 @@ class BOHB(MsgDispatcherBase):
def
handle_add_customized_trial
(
self
,
data
):
pass
def
handle_import_data
(
self
,
data
):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
Raises
------
AssertionError
data doesn't have required key 'parameter' and 'value'
"""
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
%
(
_completed_num
),
len
(
data
))
_completed_num
+=
1
assert
"parameter"
in
trial_info
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
if
_KEY
not
in
_params
:
_params
[
_KEY
]
=
self
.
max_budget
logger
.
info
(
"Set
\"
TRIAL_BUDGET
\"
value to %s (max budget)"
%
self
.
max_budget
)
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
reward
=
-
_value
else
:
reward
=
_value
_budget
=
_params
[
_KEY
]
self
.
cg
.
new_result
(
loss
=
reward
,
budget
=
_budget
,
parameters
=
_params
,
update_model
=
True
)
logger
.
info
(
"Successfully import tuning data to BOHB advisor."
)
src/sdk/pynni/nni/common.py
View file @
0663218b
...
...
@@ -19,29 +19,13 @@
# ==================================================================================================
from
collections
import
namedtuple
from
datetime
import
datetime
from
io
import
TextIOBase
import
logging
import
os
import
sys
import
time
def
_load_env_args
():
args
=
{
'platform'
:
os
.
environ
.
get
(
'NNI_PLATFORM'
),
'trial_job_id'
:
os
.
environ
.
get
(
'NNI_TRIAL_JOB_ID'
),
'log_dir'
:
os
.
environ
.
get
(
'NNI_LOG_DIRECTORY'
),
'role'
:
os
.
environ
.
get
(
'NNI_ROLE'
),
'log_level'
:
os
.
environ
.
get
(
'NNI_LOG_LEVEL'
)
}
return
namedtuple
(
'EnvArgs'
,
args
.
keys
())(
**
args
)
env_args
=
_load_env_args
()
'''Arguments passed from environment'''
logLevelMap
=
{
log_level_map
=
{
'fatal'
:
logging
.
FATAL
,
'error'
:
logging
.
ERROR
,
'warning'
:
logging
.
WARNING
,
...
...
@@ -49,7 +33,8 @@ logLevelMap = {
'debug'
:
logging
.
DEBUG
}
_time_format
=
'%m/%d/%Y, %I:%M:%S %P'
_time_format
=
'%m/%d/%Y, %I:%M:%S %p'
class
_LoggerFileWrapper
(
TextIOBase
):
def
__init__
(
self
,
logger_file
):
self
.
file
=
logger_file
...
...
@@ -61,21 +46,12 @@ class _LoggerFileWrapper(TextIOBase):
self
.
file
.
flush
()
return
len
(
s
)
def
init_logger
(
logger_file_path
):
def
init_logger
(
logger_file_path
,
log_level_name
=
'info'
):
"""Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
if
env_args
.
platform
==
'unittest'
:
logger_file_path
=
'unittest.log'
elif
env_args
.
log_dir
is
not
None
:
logger_file_path
=
os
.
path
.
join
(
env_args
.
log_dir
,
logger_file_path
)
if
env_args
.
log_level
and
logLevelMap
.
get
(
env_args
.
log_level
):
log_level
=
logLevelMap
[
env_args
.
log_level
]
else
:
log_level
=
logging
.
INFO
#default log level is INFO
log_level
=
log_level_map
.
get
(
log_level_name
,
logging
.
INFO
)
logger_file
=
open
(
logger_file_path
,
'w'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
...
...
src/sdk/pynni/nni/env_vars.py
0 → 100644
View file @
0663218b
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
from
collections
import
namedtuple
_trial_env_var_names
=
[
'NNI_PLATFORM'
,
'NNI_TRIAL_JOB_ID'
,
'NNI_SYS_DIR'
,
'NNI_OUTPUT_DIR'
,
'NNI_TRIAL_SEQ_ID'
,
'MULTI_PHASE'
]
_dispatcher_env_var_names
=
[
'NNI_MODE'
,
'NNI_CHECKPOINT_DIRECTORY'
,
'NNI_LOG_DIRECTORY'
,
'NNI_LOG_LEVEL'
,
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
]
def
_load_env_vars
(
env_var_names
):
env_var_dict
=
{
k
:
os
.
environ
.
get
(
k
)
for
k
in
env_var_names
}
return
namedtuple
(
'EnvVars'
,
env_var_names
)(
**
env_var_dict
)
trial_env_vars
=
_load_env_vars
(
_trial_env_var_names
)
dispatcher_env_vars
=
_load_env_vars
(
_dispatcher_env_var_names
)
src/sdk/pynni/nni/evolution_tuner/evolution_tuner.py
View file @
0663218b
...
...
@@ -34,7 +34,6 @@ from nni.tuner import Tuner
from
nni.utils
import
extract_scalar_reward
from
..
import
parameter_expressions
@
unique
class
OptimizeMode
(
Enum
):
"""Optimize Mode class
...
...
@@ -299,3 +298,6 @@ class EvolutionTuner(Tuner):
indiv
=
Individual
(
config
=
params
,
result
=
reward
)
self
.
population
.
append
(
indiv
)
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
View file @
0663218b
...
...
@@ -24,14 +24,17 @@ gridsearch_tuner.py including:
import
copy
import
numpy
as
np
import
logging
import
nni
from
nni.tuner
import
Tuner
from
nni.utils
import
convert_dict2tuple
TYPE
=
'_type'
CHOICE
=
'choice'
VALUE
=
'_value'
logger
=
logging
.
getLogger
(
'grid_search_AutoML'
)
class
GridSearchTuner
(
Tuner
):
'''
...
...
@@ -51,6 +54,7 @@ class GridSearchTuner(Tuner):
def
__init__
(
self
):
self
.
count
=
-
1
self
.
expanded_search_space
=
[]
self
.
supplement_data
=
dict
()
def
json2paramater
(
self
,
ss_spec
):
'''
...
...
@@ -135,9 +139,31 @@ class GridSearchTuner(Tuner):
def
generate_parameters
(
self
,
parameter_id
):
self
.
count
+=
1
if
self
.
count
>
len
(
self
.
expanded_search_space
)
-
1
:
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
return
self
.
expanded_search_space
[
self
.
count
]
while
(
self
.
count
<=
len
(
self
.
expanded_search_space
)
-
1
):
_params_tuple
=
convert_dict2tuple
(
self
.
expanded_search_space
[
self
.
count
])
if
_params_tuple
in
self
.
supplement_data
:
self
.
count
+=
1
else
:
return
self
.
expanded_search_space
[
self
.
count
]
raise
nni
.
NoMoreTrialError
(
'no more parameters now.'
)
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
):
pass
def
import_data
(
self
,
data
):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
%
(
_completed_num
),
len
(
data
))
_completed_num
+=
1
assert
"parameter"
in
trial_info
_params
=
trial_info
[
"parameter"
]
_params_tuple
=
convert_dict2tuple
(
_params
)
self
.
supplement_data
[
_params_tuple
]
=
True
logger
.
info
(
"Successfully import data to grid search tuner."
)
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
0663218b
...
...
@@ -31,7 +31,6 @@ import json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.common
import
init_logger
from
nni.utils
import
extract_scalar_reward
from
..
import
parameter_expressions
...
...
@@ -420,3 +419,6 @@ class Hyperband(MsgDispatcherBase):
def
handle_add_customized_trial
(
self
,
data
):
pass
def
handle_import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
0663218b
...
...
@@ -172,6 +172,7 @@ class HyperoptTuner(Tuner):
self
.
json
=
None
self
.
total_data
=
{}
self
.
rval
=
None
self
.
supplement_data_num
=
0
def
_choose_tuner
(
self
,
algorithm_name
):
"""
...
...
@@ -353,3 +354,27 @@ class HyperoptTuner(Tuner):
# remove '_index' from json2parameter and save params-id
total_params
=
json2parameter
(
self
.
json
,
parameter
)
return
total_params
def
import_data
(
self
,
data
):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
%
(
_completed_num
),
len
(
data
))
_completed_num
+=
1
if
self
.
algorithm_name
==
'random_search'
:
return
assert
"parameter"
in
trial_info
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
self
.
supplement_data_num
+=
1
_parameter_id
=
'_'
.
join
([
"ImportData"
,
str
(
self
.
supplement_data_num
)])
self
.
total_data
[
_parameter_id
]
=
_params
self
.
receive_trial_result
(
parameter_id
=
_parameter_id
,
parameters
=
_params
,
value
=
_value
)
logger
.
info
(
"Successfully import data to TPE/Anneal tuner."
)
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
View file @
0663218b
...
...
@@ -96,7 +96,7 @@ class MetisTuner(Tuner):
self
.
samples_x
=
[]
self
.
samples_y
=
[]
self
.
samples_y_aggregation
=
[]
self
.
history_parameters
=
[]
self
.
total_data
=
[]
self
.
space
=
None
self
.
no_resampling
=
no_resampling
self
.
no_candidates
=
no_candidates
...
...
@@ -107,6 +107,7 @@ class MetisTuner(Tuner):
self
.
exploration_probability
=
exploration_probability
self
.
minimize_constraints_fun
=
None
self
.
minimize_starting_points
=
None
self
.
supplement_data_num
=
0
def
update_search_space
(
self
,
search_space
):
...
...
@@ -392,15 +393,35 @@ class MetisTuner(Tuner):
# ===== STEP 7: If current optimal hyperparameter occurs in the history or exploration probability is less than the threshold, take next config as exploration step =====
outputs
=
self
.
_pack_output
(
lm_current
[
'hyperparameter'
])
ap
=
random
.
uniform
(
0
,
1
)
if
outputs
in
self
.
history_parameters
or
ap
<=
self
.
exploration_probability
:
if
outputs
in
self
.
total_data
or
ap
<=
self
.
exploration_probability
:
if
next_candidate
is
not
None
:
outputs
=
self
.
_pack_output
(
next_candidate
[
'hyperparameter'
])
else
:
random_parameter
=
_rand_init
(
x_bounds
,
x_types
,
1
)[
0
]
outputs
=
self
.
_pack_output
(
random_parameter
)
self
.
history_parameters
.
append
(
outputs
)
self
.
total_data
.
append
(
outputs
)
return
outputs
def
import_data
(
self
,
data
):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
%
(
_completed_num
),
len
(
data
))
_completed_num
+=
1
assert
"parameter"
in
trial_info
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
self
.
supplement_data_num
+=
1
_parameter_id
=
'_'
.
join
([
"ImportData"
,
str
(
self
.
supplement_data_num
)])
self
.
total_data
.
append
(
_params
)
self
.
receive_trial_result
(
parameter_id
=
_parameter_id
,
parameters
=
_params
,
value
=
_value
)
logger
.
info
(
"Successfully import data to metis tuner."
)
def
_rand_with_constraints
(
x_bounds
,
x_types
):
outputs
=
None
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
0663218b
...
...
@@ -27,6 +27,7 @@ from .protocol import CommandType, send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -108,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase):
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
def
handle_import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
# data: parameters
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
data
,
customized
=
True
))
def
handle_report_metric_data
(
self
,
data
):
"""
:param
data: a dict received from nni_manager, which contains:
- 'parameter_id': id of the trial
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
data: a dict received from nni_manager, which contains:
- 'parameter_id': id of the trial
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if
data
[
'type'
]
==
'FINAL'
:
self
.
_handle_final_metric_data
(
data
)
...
...
@@ -134,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
def
handle_trial_end
(
self
,
data
):
"""
data: it has three keys: trial_job_id, event, hyper_params
trial_job_id: the id generated by training service
event: the job's state
hyper_params: the hyperparameters generated and returned by tuner
-
trial_job_id: the id generated by training service
-
event: the job's state
-
hyper_params: the hyperparameters generated and returned by tuner
"""
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
...
...
@@ -190,8 +197,8 @@ class MsgDispatcher(MsgDispatcherBase):
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dumps
(
trial_job_id
))
# notify tuner
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
os
.
environ
.
get
(
'
NNI_INCLUDE_INTERMEDIATE_RESULTS
'
)
)
if
os
.
environ
.
get
(
'
NNI_INCLUDE_INTERMEDIATE_RESULTS
'
)
==
'true'
:
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
)
if
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
==
'true'
:
self
.
_earlystop_notify_tuner
(
data
)
else
:
_logger
.
debug
(
'GOOD'
)
...
...
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @
0663218b
...
...
@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool
from
queue
import
Queue
,
Empty
import
json_tricks
from
.common
import
init_logger
,
multi_thread_enabled
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
from
.utils
import
init_dispatcher_logger
from
.recoverable
import
Recoverable
from
.protocol
import
CommandType
,
receive
init_logger
(
'dispatcher.log'
)
init_dispatcher_logger
()
_logger
=
logging
.
getLogger
(
__name__
)
QUEUE_LEN_WARNING_MARK
=
20
...
...
@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable):
This function will never return unless raise.
"""
_logger
.
info
(
'Start dispatcher'
)
mode
=
os
.
getenv
(
'NNI_MODE'
)
if
mode
==
'resume'
:
if
dispatcher_env_vars
.
NNI_MODE
==
'resume'
:
self
.
load_checkpoint
()
while
True
:
...
...
@@ -142,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
CommandType
.
Initialize
:
self
.
handle_initialize
,
CommandType
.
RequestTrialJobs
:
self
.
handle_request_trial_jobs
,
CommandType
.
UpdateSearchSpace
:
self
.
handle_update_search_space
,
CommandType
.
ImportData
:
self
.
handle_import_data
,
CommandType
.
AddCustomizedTrialJob
:
self
.
handle_add_customized_trial
,
# Tunner/Assessor commands:
...
...
@@ -166,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
def
handle_update_search_space
(
self
,
data
):
raise
NotImplementedError
(
'handle_update_search_space not implemented'
)
def
handle_import_data
(
self
,
data
):
raise
NotImplementedError
(
'handle_import_data not implemented'
)
def
handle_add_customized_trial
(
self
,
data
):
raise
NotImplementedError
(
'handle_add_customized_trial not implemented'
)
...
...
src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
View file @
0663218b
...
...
@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self
.
tuner
.
update_search_space
(
data
)
return
True
def
handle_import_data
(
self
,
data
):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self
.
tuner
.
import_data
(
data
)
return
True
def
handle_add_customized_trial
(
self
,
data
):
# data: parameters
id_
=
_create_parameter_id
()
...
...
@@ -154,6 +161,9 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self
.
tuner
.
trial_end
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
,
trial_job_id
)
return
True
def
handle_import_data
(
self
,
data
):
pass
def
_handle_intermediate_metric_data
(
self
,
data
):
if
data
[
'type'
]
!=
'PERIODICAL'
:
return
True
...
...
src/sdk/pynni/nni/multi_phase/multi_phase_tuner.py
View file @
0663218b
...
...
@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
import_data
(
self
,
data
):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def
load_checkpoint
(
self
):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
...
...
@@ -95,3 +101,6 @@ class MultiPhaseTuner(Recoverable):
def
_on_error
(
self
):
pass
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/networkmorphism_tuner/networkmorphism_tuner.py
View file @
0663218b
...
...
@@ -307,3 +307,7 @@ class NetworkMorphismTuner(Tuner):
if
item
[
"model_id"
]
==
model_id
:
return
item
[
"metric_value"
]
return
None
def
import_data
(
self
,
data
):
pass
src/sdk/pynni/nni/platform/__init__.py
View file @
0663218b
...
...
@@ -21,13 +21,13 @@
# pylint: disable=wildcard-import
from
..
common
import
env_ar
g
s
from
..
env_vars
import
trial_
env_
v
ars
if
env_ar
g
s
.
platform
is
None
:
if
trial_
env_
v
ars
.
NNI_PLATFORM
is
None
:
from
.standalone
import
*
elif
env_ar
g
s
.
platform
==
'unittest'
:
elif
trial_
env_
v
ars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
elif
env_ar
g
s
.
platform
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
):
elif
trial_
env_
v
ars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
):
from
.local
import
*
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
env_ar
g
s
.
platform
)
raise
RuntimeError
(
'Unknown platform %s'
%
trial_
env_
v
ars
.
NNI_PLATFORM
)
src/sdk/pynni/nni/platform/local.py
View file @
0663218b
...
...
@@ -19,34 +19,36 @@
# ==================================================================================================
import
os
import
sys
import
json
import
time
import
json_tricks
import
subprocess
import
json_tricks
from
..common
import
init_logger
,
env_args
from
..common
import
init_logger
from
..env_vars
import
trial_env_vars
_sysdir
=
os
.
environ
[
'
NNI_SYS_DIR
'
]
_sysdir
=
trial_env_vars
.
NNI_SYS_DIR
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
_sysdir
,
'.nni'
)):
os
.
makedirs
(
os
.
path
.
join
(
_sysdir
,
'.nni'
))
_metric_file
=
open
(
os
.
path
.
join
(
_sysdir
,
'.nni'
,
'metrics'
),
'wb'
)
_outputdir
=
os
.
environ
[
'
NNI_OUTPUT_DIR
'
]
_outputdir
=
trial_env_vars
.
NNI_OUTPUT_DIR
if
not
os
.
path
.
exists
(
_outputdir
):
os
.
makedirs
(
_outputdir
)
_nni_platform
=
os
.
environ
[
'
NNI_PLATFORM
'
]
_nni_platform
=
trial_env_vars
.
NNI_PLATFORM
if
_nni_platform
==
'local'
:
_log_file_path
=
os
.
path
.
join
(
_outputdir
,
'trial.log'
)
init_logger
(
_log_file_path
)
_multiphase
=
os
.
environ
.
get
(
'
MULTI_PHASE
'
)
_multiphase
=
trial_env_vars
.
MULTI_PHASE
_param_index
=
0
def
request_next_parameter
():
metric
=
json_tricks
.
dumps
({
'trial_job_id'
:
env_ar
g
s
.
trial_job_id
,
'trial_job_id'
:
trial_
env_
v
ars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'REQUEST_PARAMETER'
,
'sequence'
:
0
,
'parameter_index'
:
_param_index
...
...
@@ -86,7 +88,11 @@ def send_metric(string):
assert
len
(
data
)
<
1000000
,
'Metric too long'
_metric_file
.
write
(
b
'ME%06d%b'
%
(
len
(
data
),
data
))
_metric_file
.
flush
()
subprocess
.
run
([
'touch'
,
_metric_file
.
name
],
check
=
True
)
if
sys
.
platform
==
"win32"
:
file
=
open
(
_metric_file
.
name
)
file
.
close
()
else
:
subprocess
.
run
([
'touch'
,
_metric_file
.
name
],
check
=
True
)
def
get_sequence_id
():
return
os
.
environ
[
'NNI_TRIAL_SEQ_ID'
]
\ No newline at end of file
return
trial_env_vars
.
NNI_TRIAL_SEQ_ID
src/sdk/pynni/nni/protocol.py
View file @
0663218b
...
...
@@ -30,6 +30,7 @@ class CommandType(Enum):
RequestTrialJobs
=
b
'GE'
ReportMetricData
=
b
'ME'
UpdateSearchSpace
=
b
'SS'
ImportData
=
b
'FD'
AddCustomizedTrialJob
=
b
'AD'
TrialEnd
=
b
'EN'
Terminate
=
b
'TE'
...
...
src/sdk/pynni/nni/recoverable.py
View file @
0663218b
...
...
@@ -24,7 +24,7 @@ class Recoverable:
def
load_checkpoint
(
self
):
pass
def
save_checkpont
(
self
):
def
save_checkpo
i
nt
(
self
):
pass
def
get_checkpoint_path
(
self
):
...
...
src/sdk/pynni/nni/smac_tuner/smac_tuner.py
View file @
0663218b
...
...
@@ -261,3 +261,6 @@ class SMACTuner(Tuner):
params
.
append
(
self
.
convert_loguniform_categorical
(
challenger
.
get_dictionary
()))
cnt
+=
1
return
params
def
import_data
(
self
,
data
):
pass
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment