Unverified Commit b7c91e73 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix bugs and lints in nnictl (#3712)


Co-authored-by: default avatarliuzhe <zhe.liu@microsoft.com>
parent 259aee75
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import importlib
import json
from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \
......
......@@ -9,11 +9,11 @@ import time
import socket
import string
import random
import yaml
import psutil
import filelock
import glob
from colorama import Fore
import filelock
import psutil
import yaml
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
......
......@@ -6,12 +6,12 @@ import logging
import os
import netifaces
from schema import And, Optional, Or, Regex, Schema, SchemaError
from nni.tools.package_utils import (
create_validator_instance,
get_all_builtin_names,
get_registered_algo_meta,
)
from schema import And, Optional, Or, Regex, Schema, SchemaError
from .common_utils import get_yml_content, print_warning
from .constants import SCHEMA_PATH_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_TYPE_ERROR
......@@ -625,7 +625,7 @@ class NNIConfigSchema:
raise SchemaError("""If no taskRoles are specified a valid custom frameworkcontroller config should
be set using the configPath attribute in frameworkcontrollerConfig!""")
config_content = get_yml_content(experiment_config.get('frameworkcontrollerConfig').get('configPath'))
if not config_content.get('spec').get('taskRoles') or not len(config_content.get('spec').get('taskRoles')):
if not config_content.get('spec').get('taskRoles') or not config_content.get('spec').get('taskRoles'):
raise SchemaError('Invalid frameworkcontroller config! No taskRoles were specified!')
if not config_content.get('spec').get('taskRoles')[0].get('task'):
raise SchemaError('Invalid frameworkcontroller config! No task was specified for taskRole!')
......
......@@ -2,12 +2,9 @@
# Licensed under the MIT license.
import os
import json_tricks
import shutil
import sqlite3
import time
import json_tricks
from .constants import NNI_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock
def config_v0_to_v1(config: dict) -> dict:
......
......@@ -17,10 +17,9 @@ from nni.tools.package_utils import get_builtin_module_class_name
import nni_node # pylint: disable=import-error
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_port, get_user
from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command
......@@ -84,7 +83,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
cmds += ['--foreground', 'true']
if url_prefix:
_validate_prefix_path(url_prefix)
setPrefixUrl(url_prefix)
set_prefix_url(url_prefix)
cmds += ['--url_prefix', url_prefix]
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
......@@ -167,7 +166,8 @@ def set_V1_common_config(experiment_config, port, config_file_name):
response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT)
validate_response(response, config_file_name)
if experiment_config.get('logCollection'):
response = rest_put(cluster_metadata_url(port), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT)
data = json.dumps({'log_collection': experiment_config.get('logCollection')})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
validate_response(response, config_file_name)
def setNNIManagerIp(experiment_config, port, config_file_name):
......@@ -229,7 +229,8 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
def set_shared_storage(experiment_config, port, config_file_name):
if 'sharedStorage' in experiment_config:
response = rest_put(cluster_metadata_url(port), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT)
data = json.dumps({'shared_storage_config': experiment_config['sharedStorage']})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -485,7 +486,10 @@ def _validate_v2(config, path):
print_error(f'Config V2 validation failed: {repr(e)}')
def _validate_prefix_path(path):
assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid."
assert not path.startswith('/'), 'URL prefix should not start with "/".'
parts = path.split('/')
valid = all(re.match('^[A-Za-z0-9_-]*$', part) for part in parts)
assert valid, 'URL prefix should only contain letter, number, underscore, and hyphen.'
def create_experiment(args):
'''start a new experiment'''
......@@ -504,7 +508,6 @@ def create_experiment(args):
config_v1 = config_yml
else:
schema = 2
from nni.experiment.config import convert
config_v2 = convert.to_v2(config_yml).json()
else:
config_v2 = _validate_v2(config_yml, config_path)
......
......@@ -6,7 +6,6 @@ import os
import sys
import json
import time
import re
import shutil
import subprocess
from functools import cmp_to_key
......@@ -528,7 +527,7 @@ def experiment_clean(args):
print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0)
# clean local data
local_base_dir = experiments_config[experiment_id]['logDir']
local_base_dir = experiments_config.experiments[experiment_id]['logDir']
if not local_base_dir:
local_base_dir = NNI_HOME_DIR
local_experiment_dir = os.path.join(local_base_dir, experiment_id)
......
......@@ -24,12 +24,12 @@ TENSORBOARD_API = '/tensorboard'
METRIC_DATA_API = '/metric-data'
def formatURLPath(path):
return API_ROOT_URL if path is None else '/{0}{1}'.format(path, API_ROOT_URL)
def format_url_path(path):
return API_ROOT_URL if path is None else f'/{path}{API_ROOT_URL}'
def setPrefixUrl(prefix_path):
def set_prefix_url(prefix_path):
global API_ROOT_URL
API_ROOT_URL = formatURLPath(prefix_path)
API_ROOT_URL = format_url_path(prefix_path)
def metric_data_url(port):
'''get metric_data url'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment