Unverified Commit 62a79828 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Refactor env var (#993)

* Refactoring environment variables
parent ca99000d
......@@ -19,29 +19,13 @@
# ==================================================================================================
from collections import namedtuple
from datetime import datetime
from io import TextIOBase
import logging
import os
import sys
import time
def _load_env_args():
args = {
'platform': os.environ.get('NNI_PLATFORM'),
'trial_job_id': os.environ.get('NNI_TRIAL_JOB_ID'),
'log_dir': os.environ.get('NNI_LOG_DIRECTORY'),
'role': os.environ.get('NNI_ROLE'),
'log_level': os.environ.get('NNI_LOG_LEVEL')
}
return namedtuple('EnvArgs', args.keys())(**args)
env_args = _load_env_args()
'''Arguments passed from environment'''
logLevelMap = {
log_level_map = {
'fatal': logging.FATAL,
'error': logging.ERROR,
'warning': logging.WARNING,
......@@ -61,21 +45,12 @@ class _LoggerFileWrapper(TextIOBase):
self.file.flush()
return len(s)
def init_logger(logger_file_path):
def init_logger(logger_file_path, log_level_name='info'):
"""Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
if env_args.platform == 'unittest':
logger_file_path = 'unittest.log'
elif env_args.log_dir is not None:
logger_file_path = os.path.join(env_args.log_dir, logger_file_path)
if env_args.log_level and logLevelMap.get(env_args.log_level):
log_level = logLevelMap[env_args.log_level]
else:
log_level = logging.INFO #default log level is INFO
log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
from collections import namedtuple
_trial_env_var_names = [
'NNI_PLATFORM',
'NNI_TRIAL_JOB_ID',
'NNI_SYS_DIR',
'NNI_OUTPUT_DIR',
'NNI_TRIAL_SEQ_ID',
'MULTI_PHASE'
]
_dispatcher_env_var_names = [
'NNI_MODE',
'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY',
'NNI_LOG_LEVEL',
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
]
def _load_env_vars(env_var_names):
env_var_dict = {k: os.environ.get(k) for k in env_var_names}
return namedtuple('EnvVars', env_var_names)(**env_var_dict)
trial_env_vars = _load_env_vars(_trial_env_var_names)
dispatcher_env_vars = _load_env_vars(_dispatcher_env_var_names)
......@@ -31,7 +31,6 @@ import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger
from nni.utils import extract_scalar_reward
from .. import parameter_expressions
......
......@@ -27,6 +27,7 @@ from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__)
......@@ -190,8 +191,8 @@ class MsgDispatcher(MsgDispatcherBase):
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')
......
......@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
import json_tricks
from .common import init_logger, multi_thread_enabled
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from .utils import init_dispatcher_logger
from .recoverable import Recoverable
from .protocol import CommandType, receive
init_logger('dispatcher.log')
init_dispatcher_logger()
_logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20
......@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable):
This function will never return unless raise.
"""
_logger.info('Start dispatcher')
mode = os.getenv('NNI_MODE')
if mode == 'resume':
if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint()
while True:
......
......@@ -21,13 +21,13 @@
# pylint: disable=wildcard-import
from ..common import env_args
from ..env_vars import trial_env_vars
if env_args.platform is None:
if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif env_args.platform == 'unittest':
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif env_args.platform in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'):
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % env_args.platform)
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
......@@ -21,32 +21,33 @@
import os
import json
import time
import json_tricks
import subprocess
import json_tricks
from ..common import init_logger, env_args
from ..common import init_logger
from ..env_vars import trial_env_vars
_sysdir = os.environ['NNI_SYS_DIR']
_sysdir = trial_env_vars.NNI_SYS_DIR
if not os.path.exists(os.path.join(_sysdir, '.nni')):
os.makedirs(os.path.join(_sysdir, '.nni'))
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb')
_outputdir = os.environ['NNI_OUTPUT_DIR']
_outputdir = trial_env_vars.NNI_OUTPUT_DIR
if not os.path.exists(_outputdir):
os.makedirs(_outputdir)
_nni_platform = os.environ['NNI_PLATFORM']
_nni_platform = trial_env_vars.NNI_PLATFORM
if _nni_platform == 'local':
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)
_multiphase = os.environ.get('MULTI_PHASE')
_multiphase = trial_env_vars.MULTI_PHASE
_param_index = 0
def request_next_parameter():
metric = json_tricks.dumps({
'trial_job_id': env_args.trial_job_id,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'REQUEST_PARAMETER',
'sequence': 0,
'parameter_index': _param_index
......@@ -89,4 +90,4 @@ def send_metric(string):
subprocess.run(['touch', _metric_file.name], check = True)
def get_sequence_id():
return os.environ['NNI_TRIAL_SEQ_ID']
\ No newline at end of file
return trial_env_vars.NNI_TRIAL_SEQ_ID
......@@ -19,11 +19,9 @@
# ==================================================================================================
import inspect
import math
import random
from .common import env_args
from .env_vars import trial_env_vars
from . import trial
......@@ -44,7 +42,7 @@ __all__ = [
# pylint: disable=unused-argument
if env_args.platform is None:
if trial_env_vars.NNI_PLATFORM is None:
def choice(*options, name=None):
return random.choice(options)
......
......@@ -21,7 +21,7 @@
import json_tricks
from .common import env_args
from .env_vars import trial_env_vars
from . import platform
......@@ -65,7 +65,7 @@ def report_intermediate_result(metric):
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
'sequence': _intermediate_seq,
'value': metric
......@@ -81,7 +81,7 @@ def report_final_result(metric):
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result'
metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0, # TODO: may be unnecessary
'value': metric
......
......@@ -18,6 +18,10 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
from .common import init_logger
from .env_vars import dispatcher_env_vars
def extract_scalar_reward(value, scalar_key='default'):
"""
Raises
......@@ -32,4 +36,11 @@ def extract_scalar_reward(value, scalar_key='default'):
reward = value[scalar_key]
else:
raise RuntimeError('Incorrect final result: the final result for %s should be float/int, or a dict which has a key named "default" whose value is float/int.' % str(self.__class__))
return reward
\ No newline at end of file
return reward
def init_dispatcher_logger():
""" Initialize dispatcher logging configuration"""
logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment