Unverified Commit b7c91e73 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix bugs and lints in nnictl (#3712)


Co-authored-by: default avatarliuzhe <zhe.liu@microsoft.com>
parent 259aee75
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import os
import importlib import importlib
import json import json
from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \ from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \
......
...@@ -9,11 +9,11 @@ import time ...@@ -9,11 +9,11 @@ import time
import socket import socket
import string import string
import random import random
import yaml
import psutil
import filelock
import glob import glob
from colorama import Fore from colorama import Fore
import filelock
import psutil
import yaml
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
......
...@@ -6,12 +6,12 @@ import logging ...@@ -6,12 +6,12 @@ import logging
import os import os
import netifaces import netifaces
from schema import And, Optional, Or, Regex, Schema, SchemaError
from nni.tools.package_utils import ( from nni.tools.package_utils import (
create_validator_instance, create_validator_instance,
get_all_builtin_names, get_all_builtin_names,
get_registered_algo_meta, get_registered_algo_meta,
) )
from schema import And, Optional, Or, Regex, Schema, SchemaError
from .common_utils import get_yml_content, print_warning from .common_utils import get_yml_content, print_warning
from .constants import SCHEMA_PATH_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_TYPE_ERROR from .constants import SCHEMA_PATH_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_TYPE_ERROR
...@@ -625,7 +625,7 @@ class NNIConfigSchema: ...@@ -625,7 +625,7 @@ class NNIConfigSchema:
raise SchemaError("""If no taskRoles are specified a valid custom frameworkcontroller config should raise SchemaError("""If no taskRoles are specified a valid custom frameworkcontroller config should
be set using the configPath attribute in frameworkcontrollerConfig!""") be set using the configPath attribute in frameworkcontrollerConfig!""")
config_content = get_yml_content(experiment_config.get('frameworkcontrollerConfig').get('configPath')) config_content = get_yml_content(experiment_config.get('frameworkcontrollerConfig').get('configPath'))
if not config_content.get('spec').get('taskRoles') or not len(config_content.get('spec').get('taskRoles')): if not config_content.get('spec').get('taskRoles') or not config_content.get('spec').get('taskRoles'):
raise SchemaError('Invalid frameworkcontroller config! No taskRoles were specified!') raise SchemaError('Invalid frameworkcontroller config! No taskRoles were specified!')
if not config_content.get('spec').get('taskRoles')[0].get('task'): if not config_content.get('spec').get('taskRoles')[0].get('task'):
raise SchemaError('Invalid frameworkcontroller config! No task was specified for taskRole!') raise SchemaError('Invalid frameworkcontroller config! No task was specified for taskRole!')
......
...@@ -2,12 +2,9 @@ ...@@ -2,12 +2,9 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import os import os
import json_tricks
import shutil
import sqlite3 import sqlite3
import time import json_tricks
from .constants import NNI_HOME_DIR from .constants import NNI_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock from .common_utils import get_file_lock
def config_v0_to_v1(config: dict) -> dict: def config_v0_to_v1(config: dict) -> dict:
......
...@@ -17,10 +17,9 @@ from nni.tools.package_utils import get_builtin_module_class_name ...@@ -17,10 +17,9 @@ from nni.tools.package_utils import get_builtin_module_class_name
import nni_node # pylint: disable=import-error import nni_node # pylint: disable=import-error
from .launcher_utils import validate_all_content from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_port, get_user
detect_port, get_user
from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
...@@ -84,7 +83,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log ...@@ -84,7 +83,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
cmds += ['--foreground', 'true'] cmds += ['--foreground', 'true']
if url_prefix: if url_prefix:
_validate_prefix_path(url_prefix) _validate_prefix_path(url_prefix)
setPrefixUrl(url_prefix) set_prefix_url(url_prefix)
cmds += ['--url_prefix', url_prefix] cmds += ['--url_prefix', url_prefix]
stdout_full_path, stderr_full_path = get_log_path(experiment_id) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
...@@ -167,7 +166,8 @@ def set_V1_common_config(experiment_config, port, config_file_name): ...@@ -167,7 +166,8 @@ def set_V1_common_config(experiment_config, port, config_file_name):
response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT) response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT)
validate_response(response, config_file_name) validate_response(response, config_file_name)
if experiment_config.get('logCollection'): if experiment_config.get('logCollection'):
response = rest_put(cluster_metadata_url(port), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT) data = json.dumps({'log_collection': experiment_config.get('logCollection')})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
validate_response(response, config_file_name) validate_response(response, config_file_name)
def setNNIManagerIp(experiment_config, port, config_file_name): def setNNIManagerIp(experiment_config, port, config_file_name):
...@@ -229,7 +229,8 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name): ...@@ -229,7 +229,8 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
def set_shared_storage(experiment_config, port, config_file_name): def set_shared_storage(experiment_config, port, config_file_name):
if 'sharedStorage' in experiment_config: if 'sharedStorage' in experiment_config:
response = rest_put(cluster_metadata_url(port), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT) data = json.dumps({'shared_storage_config': experiment_config['sharedStorage']})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
err_message = None err_message = None
if not response or not response.status_code == 200: if not response or not response.status_code == 200:
if response is not None: if response is not None:
...@@ -485,7 +486,10 @@ def _validate_v2(config, path): ...@@ -485,7 +486,10 @@ def _validate_v2(config, path):
print_error(f'Config V2 validation failed: {repr(e)}') print_error(f'Config V2 validation failed: {repr(e)}')
def _validate_prefix_path(path): def _validate_prefix_path(path):
assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." assert not path.startswith('/'), 'URL prefix should not start with "/".'
parts = path.split('/')
valid = all(re.match('^[A-Za-z0-9_-]*$', part) for part in parts)
assert valid, 'URL prefix should only contain letter, number, underscore, and hyphen.'
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
...@@ -504,7 +508,6 @@ def create_experiment(args): ...@@ -504,7 +508,6 @@ def create_experiment(args):
config_v1 = config_yml config_v1 = config_yml
else: else:
schema = 2 schema = 2
from nni.experiment.config import convert
config_v2 = convert.to_v2(config_yml).json() config_v2 = convert.to_v2(config_yml).json()
else: else:
config_v2 = _validate_v2(config_yml, config_path) config_v2 = _validate_v2(config_yml, config_path)
......
...@@ -6,7 +6,6 @@ import os ...@@ -6,7 +6,6 @@ import os
import sys import sys
import json import json
import time import time
import re
import shutil import shutil
import subprocess import subprocess
from functools import cmp_to_key from functools import cmp_to_key
...@@ -528,7 +527,7 @@ def experiment_clean(args): ...@@ -528,7 +527,7 @@ def experiment_clean(args):
print_warning('platform {0} clean up not supported yet.'.format(platform)) print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0) exit(0)
# clean local data # clean local data
local_base_dir = experiments_config[experiment_id]['logDir'] local_base_dir = experiments_config.experiments[experiment_id]['logDir']
if not local_base_dir: if not local_base_dir:
local_base_dir = NNI_HOME_DIR local_base_dir = NNI_HOME_DIR
local_experiment_dir = os.path.join(local_base_dir, experiment_id) local_experiment_dir = os.path.join(local_base_dir, experiment_id)
......
...@@ -24,12 +24,12 @@ TENSORBOARD_API = '/tensorboard' ...@@ -24,12 +24,12 @@ TENSORBOARD_API = '/tensorboard'
METRIC_DATA_API = '/metric-data' METRIC_DATA_API = '/metric-data'
def formatURLPath(path): def format_url_path(path):
return API_ROOT_URL if path is None else '/{0}{1}'.format(path, API_ROOT_URL) return API_ROOT_URL if path is None else f'/{path}{API_ROOT_URL}'
def setPrefixUrl(prefix_path): def set_prefix_url(prefix_path):
global API_ROOT_URL global API_ROOT_URL
API_ROOT_URL = formatURLPath(prefix_path) API_ROOT_URL = format_url_path(prefix_path)
def metric_data_url(port): def metric_data_url(port):
'''get metric_data url''' '''get metric_data url'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment