Unverified Commit e21a6984 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

[v2.0] Refactor code hierarchy (part 2) (#2987)

parent f98ee672
...@@ -107,7 +107,7 @@ class Mutator(BaseMutator): ...@@ -107,7 +107,7 @@ class Mutator(BaseMutator):
""" """
if not torch.__version__.startswith("1.4"): if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.") logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from nni._graph_utils import build_graph from nni.common.graph_utils import build_graph
from google.protobuf import json_format from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed # protobuf should be installed as long as tensorboard is installed
try: try:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from datetime import datetime from datetime import datetime
from io import TextIOBase from io import TextIOBase
import logging import logging
import os
import sys import sys
import time import time
...@@ -33,6 +34,9 @@ def init_logger(logger_file_path, log_level_name='info'): ...@@ -33,6 +34,9 @@ def init_logger(logger_file_path, log_level_name='info'):
This will redirect anything from logging.getLogger() as well as stdout to specified file. This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object). logger_file_path: path of logger file (path-like object).
""" """
if os.environ.get('NNI_PLATFORM') == 'unittest':
return # fixme: launching logic needs refactor
log_level = log_level_map.get(log_level_name, logging.INFO) log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w') logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
......
...@@ -8,10 +8,10 @@ import json_tricks ...@@ -8,10 +8,10 @@ import json_tricks
from nni import NoMoreTrialError from nni import NoMoreTrialError
from .protocol import CommandType, send from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from .utils import MetricType, to_json from ..utils import MetricType, to_json
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -9,8 +9,8 @@ import json_tricks ...@@ -9,8 +9,8 @@ import json_tricks
from .common import multi_thread_enabled from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from .utils import init_dispatcher_logger from ..utils import init_dispatcher_logger
from .recoverable import Recoverable from ..recoverable import Recoverable
from .protocol import CommandType, receive from .protocol import CommandType, receive
init_dispatcher_logger() init_dispatcher_logger()
......
...@@ -9,7 +9,7 @@ import subprocess ...@@ -9,7 +9,7 @@ import subprocess
from ..common import init_logger from ..common import init_logger
from ..env_vars import trial_env_vars from ..env_vars import trial_env_vars
from ..utils import to_json from nni.utils import to_json
_sysdir = trial_env_vars.NNI_SYS_DIR _sysdir = trial_env_vars.NNI_SYS_DIR
if not os.path.exists(os.path.join(_sysdir, '.nni')): if not os.path.exists(os.path.join(_sysdir, '.nni')):
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import os
import threading import threading
from enum import Enum from enum import Enum
...@@ -27,8 +28,9 @@ class CommandType(Enum): ...@@ -27,8 +28,9 @@ class CommandType(Enum):
_lock = threading.Lock() _lock = threading.Lock()
try: try:
_in_file = open(3, 'rb') if os.environ.get('NNI_PLATFORM') != 'unittest':
_out_file = open(4, 'wb') _in_file = open(3, 'rb')
_out_file = open(4, 'wb')
except OSError: except OSError:
_msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?' _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
logging.getLogger(__name__).warning(_msg) logging.getLogger(__name__).warning(_msg)
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
import numpy as np import numpy as np
from .env_vars import trial_env_vars from .runtime.env_vars import trial_env_vars
from . import trial from . import trial
from . import parameter_expressions as param_exp from . import parameter_expressions as param_exp
from .nas_utils import classic_mode, enas_mode, oneshot_mode, darts_mode from .common.nas_utils import classic_mode, enas_mode, oneshot_mode, darts_mode
__all__ = [ __all__ = [
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ast import ast
import astor import astor
from nni.nni_cmd.common_utils import print_warning from nni.tools.nnictl.common_utils import print_warning
from .utils import ast_Num, ast_Str from .utils import ast_Num, ast_Str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment