"...resnet50_tensorflow.git" did not exist on "5bba62081257cc0df8a07131acc416c8ca7dd4f7"
Unverified Commit 62a79828 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Refactor env var (#993)

* Refactoring environment variables
parent ca99000d
...@@ -19,29 +19,13 @@ ...@@ -19,29 +19,13 @@
# ================================================================================================== # ==================================================================================================
from collections import namedtuple
from datetime import datetime from datetime import datetime
from io import TextIOBase from io import TextIOBase
import logging import logging
import os
import sys import sys
import time import time
log_level_map = {
def _load_env_args():
args = {
'platform': os.environ.get('NNI_PLATFORM'),
'trial_job_id': os.environ.get('NNI_TRIAL_JOB_ID'),
'log_dir': os.environ.get('NNI_LOG_DIRECTORY'),
'role': os.environ.get('NNI_ROLE'),
'log_level': os.environ.get('NNI_LOG_LEVEL')
}
return namedtuple('EnvArgs', args.keys())(**args)
env_args = _load_env_args()
'''Arguments passed from environment'''
logLevelMap = {
'fatal': logging.FATAL, 'fatal': logging.FATAL,
'error': logging.ERROR, 'error': logging.ERROR,
'warning': logging.WARNING, 'warning': logging.WARNING,
...@@ -61,21 +45,12 @@ class _LoggerFileWrapper(TextIOBase): ...@@ -61,21 +45,12 @@ class _LoggerFileWrapper(TextIOBase):
self.file.flush() self.file.flush()
return len(s) return len(s)
def init_logger(logger_file_path): def init_logger(logger_file_path, log_level_name='info'):
"""Initialize root logger. """Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to specified file. This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object). logger_file_path: path of logger file (path-like object).
""" """
if env_args.platform == 'unittest': log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file_path = 'unittest.log'
elif env_args.log_dir is not None:
logger_file_path = os.path.join(env_args.log_dir, logger_file_path)
if env_args.log_level and logLevelMap.get(env_args.log_level):
log_level = logLevelMap[env_args.log_level]
else:
log_level = logging.INFO #default log level is INFO
logger_file = open(logger_file_path, 'w') logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime logging.Formatter.converter = time.localtime
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
from collections import namedtuple
_trial_env_var_names = [
'NNI_PLATFORM',
'NNI_TRIAL_JOB_ID',
'NNI_SYS_DIR',
'NNI_OUTPUT_DIR',
'NNI_TRIAL_SEQ_ID',
'MULTI_PHASE'
]
_dispatcher_env_var_names = [
'NNI_MODE',
'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY',
'NNI_LOG_LEVEL',
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
]
def _load_env_vars(env_var_names):
env_var_dict = {k: os.environ.get(k) for k in env_var_names}
return namedtuple('EnvVars', env_var_names)(**env_var_dict)
trial_env_vars = _load_env_vars(_trial_env_var_names)
dispatcher_env_vars = _load_env_vars(_dispatcher_env_var_names)
...@@ -31,7 +31,6 @@ import json_tricks ...@@ -31,7 +31,6 @@ import json_tricks
from nni.protocol import CommandType, send from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger
from nni.utils import extract_scalar_reward from nni.utils import extract_scalar_reward
from .. import parameter_expressions from .. import parameter_expressions
......
...@@ -27,6 +27,7 @@ from .protocol import CommandType, send ...@@ -27,6 +27,7 @@ from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult from .assessor import AssessResult
from .common import multi_thread_enabled from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -190,8 +191,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -190,8 +191,8 @@ class MsgDispatcher(MsgDispatcherBase):
_logger.debug('BAD, kill %s', trial_job_id) _logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id)) send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner # notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS')) _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true': if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
self._earlystop_notify_tuner(data) self._earlystop_notify_tuner(data)
else: else:
_logger.debug('GOOD') _logger.debug('GOOD')
......
...@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool ...@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty from queue import Queue, Empty
import json_tricks import json_tricks
from .common import init_logger, multi_thread_enabled from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from .utils import init_dispatcher_logger
from .recoverable import Recoverable from .recoverable import Recoverable
from .protocol import CommandType, receive from .protocol import CommandType, receive
init_logger('dispatcher.log') init_dispatcher_logger()
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20 QUEUE_LEN_WARNING_MARK = 20
...@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable): ...@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable):
This function will never return unless raise. This function will never return unless raise.
""" """
_logger.info('Start dispatcher') _logger.info('Start dispatcher')
mode = os.getenv('NNI_MODE') if dispatcher_env_vars.NNI_MODE == 'resume':
if mode == 'resume':
self.load_checkpoint() self.load_checkpoint()
while True: while True:
......
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from ..common import env_args from ..env_vars import trial_env_vars
if env_args.platform is None: if trial_env_vars.NNI_PLATFORM is None:
from .standalone import * from .standalone import *
elif env_args.platform == 'unittest': elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import * from .test import *
elif env_args.platform in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'): elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'):
from .local import * from .local import *
else: else:
raise RuntimeError('Unknown platform %s' % env_args.platform) raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
...@@ -21,32 +21,33 @@ ...@@ -21,32 +21,33 @@
import os import os
import json import json
import time import time
import json_tricks
import subprocess import subprocess
import json_tricks
from ..common import init_logger, env_args from ..common import init_logger
from ..env_vars import trial_env_vars
_sysdir = os.environ['NNI_SYS_DIR'] _sysdir = trial_env_vars.NNI_SYS_DIR
if not os.path.exists(os.path.join(_sysdir, '.nni')): if not os.path.exists(os.path.join(_sysdir, '.nni')):
os.makedirs(os.path.join(_sysdir, '.nni')) os.makedirs(os.path.join(_sysdir, '.nni'))
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb') _metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb')
_outputdir = os.environ['NNI_OUTPUT_DIR'] _outputdir = trial_env_vars.NNI_OUTPUT_DIR
if not os.path.exists(_outputdir): if not os.path.exists(_outputdir):
os.makedirs(_outputdir) os.makedirs(_outputdir)
_nni_platform = os.environ['NNI_PLATFORM'] _nni_platform = trial_env_vars.NNI_PLATFORM
if _nni_platform == 'local': if _nni_platform == 'local':
_log_file_path = os.path.join(_outputdir, 'trial.log') _log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path) init_logger(_log_file_path)
_multiphase = os.environ.get('MULTI_PHASE') _multiphase = trial_env_vars.MULTI_PHASE
_param_index = 0 _param_index = 0
def request_next_parameter(): def request_next_parameter():
metric = json_tricks.dumps({ metric = json_tricks.dumps({
'trial_job_id': env_args.trial_job_id, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'REQUEST_PARAMETER', 'type': 'REQUEST_PARAMETER',
'sequence': 0, 'sequence': 0,
'parameter_index': _param_index 'parameter_index': _param_index
...@@ -89,4 +90,4 @@ def send_metric(string): ...@@ -89,4 +90,4 @@ def send_metric(string):
subprocess.run(['touch', _metric_file.name], check = True) subprocess.run(['touch', _metric_file.name], check = True)
def get_sequence_id(): def get_sequence_id():
return os.environ['NNI_TRIAL_SEQ_ID'] return trial_env_vars.NNI_TRIAL_SEQ_ID
\ No newline at end of file
...@@ -19,11 +19,9 @@ ...@@ -19,11 +19,9 @@
# ================================================================================================== # ==================================================================================================
import inspect
import math
import random import random
from .common import env_args from .env_vars import trial_env_vars
from . import trial from . import trial
...@@ -44,7 +42,7 @@ __all__ = [ ...@@ -44,7 +42,7 @@ __all__ = [
# pylint: disable=unused-argument # pylint: disable=unused-argument
if env_args.platform is None: if trial_env_vars.NNI_PLATFORM is None:
def choice(*options, name=None): def choice(*options, name=None):
return random.choice(options) return random.choice(options)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
import json_tricks import json_tricks
from .common import env_args from .env_vars import trial_env_vars
from . import platform from . import platform
...@@ -65,7 +65,7 @@ def report_intermediate_result(metric): ...@@ -65,7 +65,7 @@ def report_intermediate_result(metric):
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result' assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = json_tricks.dumps({ metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'], 'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
'sequence': _intermediate_seq, 'sequence': _intermediate_seq,
'value': metric 'value': metric
...@@ -81,7 +81,7 @@ def report_final_result(metric): ...@@ -81,7 +81,7 @@ def report_final_result(metric):
assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result' assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result'
metric = json_tricks.dumps({ metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'], 'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, # TODO: may be unnecessary 'sequence': 0, # TODO: may be unnecessary
'value': metric 'value': metric
......
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ================================================================================================== # ==================================================================================================
import os
from .common import init_logger
from .env_vars import dispatcher_env_vars
def extract_scalar_reward(value, scalar_key='default'): def extract_scalar_reward(value, scalar_key='default'):
""" """
Raises Raises
...@@ -32,4 +36,11 @@ def extract_scalar_reward(value, scalar_key='default'): ...@@ -32,4 +36,11 @@ def extract_scalar_reward(value, scalar_key='default'):
reward = value[scalar_key] reward = value[scalar_key]
else: else:
raise RuntimeError('Incorrect final result: the final result for %s should be float/int, or a dict which has a key named "default" whose value is float/int.' % str(self.__class__)) raise RuntimeError('Incorrect final result: the final result for %s should be float/int, or a dict which has a key named "default" whose value is float/int.' % str(self.__class__))
return reward return reward
\ No newline at end of file
def init_dispatcher_logger():
""" Initialize dispatcher logging configuration"""
logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment