Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
62a79828
Unverified
Commit
62a79828
authored
Apr 19, 2019
by
chicm-ms
Committed by
GitHub
Apr 19, 2019
Browse files
Refactor env var (#993)
* Refactoring environment variables
parent
ca99000d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
91 additions
and
56 deletions
+91
-56
src/sdk/pynni/nni/common.py
src/sdk/pynni/nni/common.py
+3
-28
src/sdk/pynni/nni/env_vars.py
src/sdk/pynni/nni/env_vars.py
+48
-0
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+0
-1
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+3
-2
src/sdk/pynni/nni/msg_dispatcher_base.py
src/sdk/pynni/nni/msg_dispatcher_base.py
+6
-4
src/sdk/pynni/nni/platform/__init__.py
src/sdk/pynni/nni/platform/__init__.py
+5
-5
src/sdk/pynni/nni/platform/local.py
src/sdk/pynni/nni/platform/local.py
+9
-8
src/sdk/pynni/nni/smartparam.py
src/sdk/pynni/nni/smartparam.py
+2
-4
src/sdk/pynni/nni/trial.py
src/sdk/pynni/nni/trial.py
+3
-3
src/sdk/pynni/nni/utils.py
src/sdk/pynni/nni/utils.py
+12
-1
No files found.
src/sdk/pynni/nni/common.py
View file @
62a79828
...
...
@@ -19,29 +19,13 @@
# ==================================================================================================
from
collections
import
namedtuple
from
datetime
import
datetime
from
io
import
TextIOBase
import
logging
import
os
import
sys
import
time
def
_load_env_args
():
args
=
{
'platform'
:
os
.
environ
.
get
(
'NNI_PLATFORM'
),
'trial_job_id'
:
os
.
environ
.
get
(
'NNI_TRIAL_JOB_ID'
),
'log_dir'
:
os
.
environ
.
get
(
'NNI_LOG_DIRECTORY'
),
'role'
:
os
.
environ
.
get
(
'NNI_ROLE'
),
'log_level'
:
os
.
environ
.
get
(
'NNI_LOG_LEVEL'
)
}
return
namedtuple
(
'EnvArgs'
,
args
.
keys
())(
**
args
)
env_args
=
_load_env_args
()
'''Arguments passed from environment'''
logLevelMap
=
{
log_level_map
=
{
'fatal'
:
logging
.
FATAL
,
'error'
:
logging
.
ERROR
,
'warning'
:
logging
.
WARNING
,
...
...
@@ -61,21 +45,12 @@ class _LoggerFileWrapper(TextIOBase):
self
.
file
.
flush
()
return
len
(
s
)
def
init_logger
(
logger_file_path
):
def
init_logger
(
logger_file_path
,
log_level_name
=
'info'
):
"""Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
if
env_args
.
platform
==
'unittest'
:
logger_file_path
=
'unittest.log'
elif
env_args
.
log_dir
is
not
None
:
logger_file_path
=
os
.
path
.
join
(
env_args
.
log_dir
,
logger_file_path
)
if
env_args
.
log_level
and
logLevelMap
.
get
(
env_args
.
log_level
):
log_level
=
logLevelMap
[
env_args
.
log_level
]
else
:
log_level
=
logging
.
INFO
#default log level is INFO
log_level
=
log_level_map
.
get
(
log_level_name
,
logging
.
INFO
)
logger_file
=
open
(
logger_file_path
,
'w'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
...
...
src/sdk/pynni/nni/env_vars.py
0 → 100644
View file @
62a79828
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
from
collections
import
namedtuple
_trial_env_var_names
=
[
'NNI_PLATFORM'
,
'NNI_TRIAL_JOB_ID'
,
'NNI_SYS_DIR'
,
'NNI_OUTPUT_DIR'
,
'NNI_TRIAL_SEQ_ID'
,
'MULTI_PHASE'
]
_dispatcher_env_var_names
=
[
'NNI_MODE'
,
'NNI_CHECKPOINT_DIRECTORY'
,
'NNI_LOG_DIRECTORY'
,
'NNI_LOG_LEVEL'
,
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
]
def
_load_env_vars
(
env_var_names
):
env_var_dict
=
{
k
:
os
.
environ
.
get
(
k
)
for
k
in
env_var_names
}
return
namedtuple
(
'EnvVars'
,
env_var_names
)(
**
env_var_dict
)
trial_env_vars
=
_load_env_vars
(
_trial_env_var_names
)
dispatcher_env_vars
=
_load_env_vars
(
_dispatcher_env_var_names
)
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
62a79828
...
...
@@ -31,7 +31,6 @@ import json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.common
import
init_logger
from
nni.utils
import
extract_scalar_reward
from
..
import
parameter_expressions
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
62a79828
...
...
@@ -27,6 +27,7 @@ from .protocol import CommandType, send
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -190,8 +191,8 @@ class MsgDispatcher(MsgDispatcherBase):
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dumps
(
trial_job_id
))
# notify tuner
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
os
.
environ
.
get
(
'
NNI_INCLUDE_INTERMEDIATE_RESULTS
'
)
)
if
os
.
environ
.
get
(
'
NNI_INCLUDE_INTERMEDIATE_RESULTS
'
)
==
'true'
:
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
)
if
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
==
'true'
:
self
.
_earlystop_notify_tuner
(
data
)
else
:
_logger
.
debug
(
'GOOD'
)
...
...
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @
62a79828
...
...
@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool
from
queue
import
Queue
,
Empty
import
json_tricks
from
.common
import
init_logger
,
multi_thread_enabled
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
from
.utils
import
init_dispatcher_logger
from
.recoverable
import
Recoverable
from
.protocol
import
CommandType
,
receive
init_logger
(
'dispatcher.log'
)
init_dispatcher_logger
()
_logger
=
logging
.
getLogger
(
__name__
)
QUEUE_LEN_WARNING_MARK
=
20
...
...
@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable):
This function will never return unless raise.
"""
_logger
.
info
(
'Start dispatcher'
)
mode
=
os
.
getenv
(
'NNI_MODE'
)
if
mode
==
'resume'
:
if
dispatcher_env_vars
.
NNI_MODE
==
'resume'
:
self
.
load_checkpoint
()
while
True
:
...
...
src/sdk/pynni/nni/platform/__init__.py
View file @
62a79828
...
...
@@ -21,13 +21,13 @@
# pylint: disable=wildcard-import
from
..
common
import
env_ar
g
s
from
..
env_vars
import
trial_
env_
v
ars
if
env_ar
g
s
.
platform
is
None
:
if
trial_
env_
v
ars
.
NNI_PLATFORM
is
None
:
from
.standalone
import
*
elif
env_ar
g
s
.
platform
==
'unittest'
:
elif
trial_
env_
v
ars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
elif
env_ar
g
s
.
platform
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
):
elif
trial_
env_
v
ars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
):
from
.local
import
*
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
env_ar
g
s
.
platform
)
raise
RuntimeError
(
'Unknown platform %s'
%
trial_
env_
v
ars
.
NNI_PLATFORM
)
src/sdk/pynni/nni/platform/local.py
View file @
62a79828
...
...
@@ -21,32 +21,33 @@
import
os
import
json
import
time
import
json_tricks
import
subprocess
import
json_tricks
from
..common
import
init_logger
,
env_args
from
..common
import
init_logger
from
..env_vars
import
trial_env_vars
_sysdir
=
os
.
environ
[
'
NNI_SYS_DIR
'
]
_sysdir
=
trial_env_vars
.
NNI_SYS_DIR
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
_sysdir
,
'.nni'
)):
os
.
makedirs
(
os
.
path
.
join
(
_sysdir
,
'.nni'
))
_metric_file
=
open
(
os
.
path
.
join
(
_sysdir
,
'.nni'
,
'metrics'
),
'wb'
)
_outputdir
=
os
.
environ
[
'
NNI_OUTPUT_DIR
'
]
_outputdir
=
trial_env_vars
.
NNI_OUTPUT_DIR
if
not
os
.
path
.
exists
(
_outputdir
):
os
.
makedirs
(
_outputdir
)
_nni_platform
=
os
.
environ
[
'
NNI_PLATFORM
'
]
_nni_platform
=
trial_env_vars
.
NNI_PLATFORM
if
_nni_platform
==
'local'
:
_log_file_path
=
os
.
path
.
join
(
_outputdir
,
'trial.log'
)
init_logger
(
_log_file_path
)
_multiphase
=
os
.
environ
.
get
(
'
MULTI_PHASE
'
)
_multiphase
=
trial_env_vars
.
MULTI_PHASE
_param_index
=
0
def
request_next_parameter
():
metric
=
json_tricks
.
dumps
({
'trial_job_id'
:
env_ar
g
s
.
trial_job_id
,
'trial_job_id'
:
trial_
env_
v
ars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'REQUEST_PARAMETER'
,
'sequence'
:
0
,
'parameter_index'
:
_param_index
...
...
@@ -89,4 +90,4 @@ def send_metric(string):
subprocess
.
run
([
'touch'
,
_metric_file
.
name
],
check
=
True
)
def
get_sequence_id
():
return
os
.
environ
[
'NNI_TRIAL_SEQ_ID'
]
\ No newline at end of file
return
trial_env_vars
.
NNI_TRIAL_SEQ_ID
src/sdk/pynni/nni/smartparam.py
View file @
62a79828
...
...
@@ -19,11 +19,9 @@
# ==================================================================================================
import
inspect
import
math
import
random
from
.
common
import
env_ar
g
s
from
.
env_vars
import
trial_
env_
v
ars
from
.
import
trial
...
...
@@ -44,7 +42,7 @@ __all__ = [
# pylint: disable=unused-argument
if
env_ar
g
s
.
platform
is
None
:
if
trial_
env_
v
ars
.
NNI_PLATFORM
is
None
:
def
choice
(
*
options
,
name
=
None
):
return
random
.
choice
(
options
)
...
...
src/sdk/pynni/nni/trial.py
View file @
62a79828
...
...
@@ -21,7 +21,7 @@
import
json_tricks
from
.
common
import
env_ar
g
s
from
.
env_vars
import
trial_
env_
v
ars
from
.
import
platform
...
...
@@ -65,7 +65,7 @@ def report_intermediate_result(metric):
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric
=
json_tricks
.
dumps
({
'parameter_id'
:
_params
[
'parameter_id'
],
'trial_job_id'
:
env_ar
g
s
.
trial_job_id
,
'trial_job_id'
:
trial_
env_
v
ars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'PERIODICAL'
,
'sequence'
:
_intermediate_seq
,
'value'
:
metric
...
...
@@ -81,7 +81,7 @@ def report_final_result(metric):
assert
_params
is
not
None
,
'nni.get_next_parameter() needs to be called before report_final_result'
metric
=
json_tricks
.
dumps
({
'parameter_id'
:
_params
[
'parameter_id'
],
'trial_job_id'
:
env_ar
g
s
.
trial_job_id
,
'trial_job_id'
:
trial_
env_
v
ars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'FINAL'
,
'sequence'
:
0
,
# TODO: may be unnecessary
'value'
:
metric
...
...
src/sdk/pynni/nni/utils.py
View file @
62a79828
...
...
@@ -18,6 +18,10 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import
os
from
.common
import
init_logger
from
.env_vars
import
dispatcher_env_vars
def
extract_scalar_reward
(
value
,
scalar_key
=
'default'
):
"""
Raises
...
...
@@ -32,4 +36,11 @@ def extract_scalar_reward(value, scalar_key='default'):
reward
=
value
[
scalar_key
]
else
:
raise
RuntimeError
(
'Incorrect final result: the final result for %s should be float/int, or a dict which has a key named "default" whose value is float/int.'
%
str
(
self
.
__class__
))
return
reward
\ No newline at end of file
return
reward
def
init_dispatcher_logger
():
""" Initialize dispatcher logging configuration"""
logger_file_path
=
'dispatcher.log'
if
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
is
not
None
:
logger_file_path
=
os
.
path
.
join
(
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
,
logger_file_path
)
init_logger
(
logger_file_path
,
dispatcher_env_vars
.
NNI_LOG_LEVEL
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment