Unverified Commit e21a6984 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

[v2.0] Refactor code hierarchy (part 2) (#2987)

parent f98ee672
......@@ -107,7 +107,7 @@ class Mutator(BaseMutator):
"""
if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from nni._graph_utils import build_graph
from nni.common.graph_utils import build_graph
from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed
try:
......
......@@ -4,6 +4,7 @@
from datetime import datetime
from io import TextIOBase
import logging
import os
import sys
import time
......@@ -33,6 +34,9 @@ def init_logger(logger_file_path, log_level_name='info'):
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
if os.environ.get('NNI_PLATFORM') == 'unittest':
return # fixme: launching logic needs refactor
log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
......
......@@ -8,10 +8,10 @@ import json_tricks
from nni import NoMoreTrialError
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from .utils import MetricType, to_json
from ..utils import MetricType, to_json
_logger = logging.getLogger(__name__)
......
......@@ -9,8 +9,8 @@ import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from .utils import init_dispatcher_logger
from .recoverable import Recoverable
from ..utils import init_dispatcher_logger
from ..recoverable import Recoverable
from .protocol import CommandType, receive
init_dispatcher_logger()
......
......@@ -9,7 +9,7 @@ import subprocess
from ..common import init_logger
from ..env_vars import trial_env_vars
from ..utils import to_json
from nni.utils import to_json
_sysdir = trial_env_vars.NNI_SYS_DIR
if not os.path.exists(os.path.join(_sysdir, '.nni')):
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import logging
import os
import threading
from enum import Enum
......@@ -27,8 +28,9 @@ class CommandType(Enum):
_lock = threading.Lock()
try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
if os.environ.get('NNI_PLATFORM') != 'unittest':
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
except OSError:
_msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
logging.getLogger(__name__).warning(_msg)
......
......@@ -3,10 +3,10 @@
import numpy as np
from .env_vars import trial_env_vars
from .runtime.env_vars import trial_env_vars
from . import trial
from . import parameter_expressions as param_exp
from .nas_utils import classic_mode, enas_mode, oneshot_mode, darts_mode
from .common.nas_utils import classic_mode, enas_mode, oneshot_mode, darts_mode
__all__ = [
......
......@@ -3,7 +3,7 @@
import ast
import astor
from nni.nni_cmd.common_utils import print_warning
from nni.tools.nnictl.common_utils import print_warning
from .utils import ast_Num, ast_Str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment