Unverified commit 21b48d29, authored by SparkSnail and committed by GitHub
Browse files

Support showing version check error message in WebUI (#922)

parent 0330333c
...@@ -41,6 +41,10 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -41,6 +41,10 @@ export abstract class ClusterJobRestServer extends RestServer{
private readonly expId: string = getExperimentId(); private readonly expId: string = getExperimentId();
private enableVersionCheck: boolean = true; //switch to enable version check
private versionCheckSuccess: boolean | undefined;
private errorMessage?: string;
/** /**
* constructor to provide NNIRestServer's own rest property, e.g. port * constructor to provide NNIRestServer's own rest property, e.g. port
*/ */
...@@ -58,6 +62,14 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -58,6 +62,14 @@ export abstract class ClusterJobRestServer extends RestServer{
} }
return this.port; return this.port;
} }
/**
 * Error message recorded by this rest server, if any.
 * @returns the stored message, or undefined when none has been recorded
 */
public get getErrorMessage(): string | undefined {
    const recordedMessage: string | undefined = this.errorMessage;

    return recordedMessage;
}
/**
 * Switches the version check on or off.
 * @param versionCheck true to enable the version check, false to skip it
 */
public set setEnableVersionCheck(versionCheck: boolean) {
    const enabled: boolean = versionCheck;
    this.enableVersionCheck = enabled;
}
/** /**
* NNIRestServer's own router registration * NNIRestServer's own router registration
...@@ -77,6 +89,31 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -77,6 +89,31 @@ export abstract class ClusterJobRestServer extends RestServer{
next(); next();
}); });
// Receives the version check result posted to /version/<expId>/<trialId>.
// A body tagged 'VCSuccess' marks the check as passed; any other tag marks it as
// failed and stores the reported message so callers can read it via getErrorMessage.
router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
    if (this.enableVersionCheck) {
        try {
            const checkResultSuccess: boolean = req.body.tag === 'VCSuccess';
            if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
                // An earlier report disagreed with this one.
                this.errorMessage = 'Version check error, version check result is inconsistent!';
                this.log.error(this.errorMessage);
            } else if (checkResultSuccess) {
                this.log.info(`Version check in trialKeeper success!`);
                this.versionCheckSuccess = true;
            } else {
                this.versionCheckSuccess = false;
                this.errorMessage = req.body.msg;
            }
        } catch (err) {
            this.log.error(`version check error: ${err}`);
            res.status(500);
            res.send(err instanceof Error ? err.message : String(err));

            // The response is already sent; falling through would send it a second time.
            return;
        }
    } else {
        this.log.info(`Skipping version check!`);
    }
    res.send();
});
router.post(`/update-metrics/${this.expId}/:trialId`, (req: Request, res: Response) => { router.post(`/update-metrics/${this.expId}/:trialId`, (req: Request, res: Response) => {
try { try {
this.log.info(`Get update-metrics request, trial job id is ${req.params.trialId}`); this.log.info(`Get update-metrics request, trial job id is ${req.params.trialId}`);
...@@ -94,6 +131,10 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -94,6 +131,10 @@ export abstract class ClusterJobRestServer extends RestServer{
}); });
router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => { router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
if(this.enableVersionCheck && !this.versionCheckSuccess && !this.errorMessage) {
this.errorMessage = `Version check failed, didn't get version check response from trialKeeper, please check your NNI version in `
+ `NNIManager and TrialKeeper!`
}
const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`); const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`);
try { try {
let skipLogging: boolean = false; let skipLogging: boolean = false;
......
...@@ -66,11 +66,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -66,11 +66,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
throw new Error('kubernetesJobRestServer not initialized!'); throw new Error('kubernetesJobRestServer not initialized!');
} }
await this.kubernetesJobRestServer.start(); await this.kubernetesJobRestServer.start();
this.kubernetesJobRestServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`frameworkcontroller Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`); this.log.info(`frameworkcontroller Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`);
while (!this.stopping) { while (!this.stopping) {
// collect metrics for frameworkcontroller jobs by interacting with Kubernetes API server // collect metrics for frameworkcontroller jobs by interacting with Kubernetes API server
await delay(3000); await delay(3000);
await this.fcJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient); await this.fcJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient);
// Abort the run loop when the rest server has recorded an error message.
if (this.kubernetesJobRestServer.getErrorMessage) {
    // Set the flag before throwing: a statement placed after `throw` never runs.
    this.stopping = true;
    throw new Error(this.kubernetesJobRestServer.getErrorMessage);
}
} }
} }
......
...@@ -71,11 +71,16 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -71,11 +71,16 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
throw new Error('kubernetesJobRestServer not initialized!'); throw new Error('kubernetesJobRestServer not initialized!');
} }
await this.kubernetesJobRestServer.start(); await this.kubernetesJobRestServer.start();
this.kubernetesJobRestServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`Kubeflow Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`); this.log.info(`Kubeflow Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`);
while (!this.stopping) { while (!this.stopping) {
// collect metrics for Kubeflow jobs by interacting with Kubernetes API server // collect metrics for Kubeflow jobs by interacting with Kubernetes API server
await delay(3000); await delay(3000);
await this.kubeflowJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient); await this.kubeflowJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient);
// Abort the run loop when the rest server has recorded an error message.
if (this.kubernetesJobRestServer.getErrorMessage) {
    // Set the flag before throwing: a statement placed after `throw` never runs.
    this.stopping = true;
    throw new Error(this.kubernetesJobRestServer.getErrorMessage);
}
} }
this.log.info('Kubeflow training service exit.'); this.log.info('Kubeflow training service exit.');
} }
......
...@@ -71,5 +71,5 @@ mkdir -p $NNI_OUTPUT_DIR ...@@ -71,5 +71,5 @@ mkdir -p $NNI_OUTPUT_DIR
cp -rT $NNI_CODE_DIR $NNI_SYS_DIR cp -rT $NNI_CODE_DIR $NNI_SYS_DIR
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
sh install_nni.sh sh install_nni.sh
python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip {9} --nnimanager_port {10} --version '{11}' --log_collection '{12}'` python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip {9} --nnimanager_port {10} --nni_manager_version '{11}' --log_collection '{12}'`
+ `1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr` + `1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`
...@@ -61,7 +61,7 @@ abstract class KubernetesTrainingService { ...@@ -61,7 +61,7 @@ abstract class KubernetesTrainingService {
protected kubernetesCRDClient?: KubernetesCRDClient; protected kubernetesCRDClient?: KubernetesCRDClient;
protected kubernetesJobRestServer?: KubernetesJobRestServer; protected kubernetesJobRestServer?: KubernetesJobRestServer;
protected kubernetesClusterConfig?: KubernetesClusterConfig; protected kubernetesClusterConfig?: KubernetesClusterConfig;
protected versionCheck?: boolean = true; protected versionCheck: boolean = true;
protected logCollection: string; protected logCollection: string;
constructor() { constructor() {
......
...@@ -64,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string = ...@@ -64,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
&& cd $NNI_SYS_DIR && sh install_nni.sh && cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' && python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}'
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' --version '{12}' --log_collection '{13}'`; --pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' --nni_manager_version '{12}' --log_collection '{13}'`;
export const PAI_OUTPUT_DIR_FORMAT: string = export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`; `hdfs://{0}:9000/`;
......
...@@ -75,7 +75,7 @@ class PAITrainingService implements TrainingService { ...@@ -75,7 +75,7 @@ class PAITrainingService implements TrainingService {
private paiRestServerPort?: number; private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig; private nniManagerIpConfig?: NNIManagerIpConfig;
private copyExpCodeDirPromise?: Promise<void>; private copyExpCodeDirPromise?: Promise<void>;
private versionCheck?: boolean = true; private versionCheck: boolean = true;
private logCollection: string; private logCollection: string;
constructor() { constructor() {
...@@ -97,11 +97,15 @@ class PAITrainingService implements TrainingService { ...@@ -97,11 +97,15 @@ class PAITrainingService implements TrainingService {
this.log.info('Run PAI training service.'); this.log.info('Run PAI training service.');
const restServer: PAIJobRestServer = component.get(PAIJobRestServer); const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
await restServer.start(); await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`); this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);
while (!this.stopping) { while (!this.stopping) {
await this.updatePaiToken(); await this.updatePaiToken();
await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig); await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
// Abort the run loop when the rest server has recorded an error message.
if (restServer.getErrorMessage) {
    // Set the flag before throwing: a statement placed after `throw` never runs.
    this.stopping = true;
    throw new Error(restServer.getErrorMessage);
}
await delay(3000); await delay(3000);
} }
this.log.info('PAI training service exit.'); this.log.info('PAI training service exit.');
......
...@@ -250,7 +250,7 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={ ...@@ -250,7 +250,7 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
sh install_nni.sh sh install_nni.sh
echo $$ >{6} echo $$ >{6}
python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' --version '{10}' --log_collection '{11}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' --nni_manager_version '{10}' --log_collection '{11}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{12}`; echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
......
...@@ -102,6 +102,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -102,6 +102,7 @@ class RemoteMachineTrainingService implements TrainingService {
public async run(): Promise<void> { public async run(): Promise<void> {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer); const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
await restServer.start(); await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info('Run remote machine training service.'); this.log.info('Run remote machine training service.');
while (!this.stopping) { while (!this.stopping) {
while (this.jobQueue.length > 0) { while (this.jobQueue.length > 0) {
...@@ -117,6 +118,10 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -117,6 +118,10 @@ class RemoteMachineTrainingService implements TrainingService {
break; break;
} }
} }
// Abort the run loop when the rest server has recorded an error message.
if (restServer.getErrorMessage) {
    // Set the flag before throwing: a statement placed after `throw` never runs.
    this.stopping = true;
    throw new Error(restServer.getErrorMessage);
}
await delay(3000); await delay(3000);
} }
this.log.info('Remote machine training service exit.'); this.log.info('Remote machine training service exit.');
......
...@@ -35,6 +35,7 @@ STDOUT_FULL_PATH = os.path.join(LOG_DIR, 'stdout') ...@@ -35,6 +35,7 @@ STDOUT_FULL_PATH = os.path.join(LOG_DIR, 'stdout')
STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr') STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
STDOUT_API = '/stdout' STDOUT_API = '/stdout'
VERSION_API = '/version'
NNI_SYS_DIR = os.environ['NNI_SYS_DIR'] NNI_SYS_DIR = os.environ['NNI_SYS_DIR']
NNI_TRIAL_JOB_ID = os.environ['NNI_TRIAL_JOB_ID'] NNI_TRIAL_JOB_ID = os.environ['NNI_TRIAL_JOB_ID']
NNI_EXP_ID = os.environ['NNI_EXP_ID'] NNI_EXP_ID = os.environ['NNI_EXP_ID']
\ No newline at end of file
...@@ -27,14 +27,18 @@ import shlex ...@@ -27,14 +27,18 @@ import shlex
import re import re
import sys import sys
import select import select
import json
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
import pkg_resources import pkg_resources
from .rest_utils import rest_post
from .url_utils import gen_send_stdout_url, gen_send_version_url
from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal
from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType
logger = logging.getLogger('trial_keeper') logger = logging.getLogger('trial_keeper')
# Captures the leading "major[.minor]" part of a version string such as "v0.5.1".
# Each component may span several digits, so "0.10" is not truncated to "0.1".
regular = re.compile(r'v?(?P<version>[0-9]+(\.[0-9]+){0,1}).*')
def main_loop(args): def main_loop(args):
'''main loop logic for trial keeper''' '''main loop logic for trial keeper'''
...@@ -110,21 +114,27 @@ def check_version(args): ...@@ -110,21 +114,27 @@ def check_version(args):
#package nni does not exist, try nni-tool package #package nni does not exist, try nni-tool package
nni_log(LogType.Error, 'Package nni does not exist!') nni_log(LogType.Error, 'Package nni does not exist!')
os._exit(1) os._exit(1)
if not args.version: if not args.nni_manager_version:
# skip version check # skip version check
nni_log(LogType.Warning, 'Skipping version check!') nni_log(LogType.Warning, 'Skipping version check!')
else: else:
regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,2}).*')
try: try:
trial_keeper_version = regular.search(trial_keeper_version).group('version') trial_keeper_version = regular.search(trial_keeper_version).group('version')
nni_log(LogType.Info, 'trial_keeper_version is {0}'.format(trial_keeper_version)) nni_log(LogType.Info, 'trial_keeper_version is {0}'.format(trial_keeper_version))
training_service_version = regular.search(args.version).group('version') nni_manager_version = regular.search(args.nni_manager_version).group('version')
nni_log(LogType.Info, 'training_service_version is {0}'.format(training_service_version)) nni_log(LogType.Info, 'nni_manager_version is {0}'.format(nni_manager_version))
if trial_keeper_version != training_service_version: log_entry = {}
if trial_keeper_version != nni_manager_version:
nni_log(LogType.Error, 'Version does not match!') nni_log(LogType.Error, 'Version does not match!')
error_message = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(nni_manager_version, trial_keeper_version)
log_entry['tag'] = 'VCFail'
log_entry['msg'] = error_message
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False)
os._exit(1) os._exit(1)
else: else:
nni_log(LogType.Info, 'Version match!') nni_log(LogType.Info, 'Version match!')
log_entry['tag'] = 'VCSuccess'
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False)
except AttributeError as err: except AttributeError as err:
nni_log(LogType.Error, err) nni_log(LogType.Error, err)
...@@ -142,7 +152,7 @@ if __name__ == '__main__': ...@@ -142,7 +152,7 @@ if __name__ == '__main__':
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs') PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL') PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL')
PARSER.add_argument('--version', type=str, help='the nni version transmitted from trainingService') PARSER.add_argument('--nni_manager_version', type=str, help='the nni version transmitted from nniManager')
PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trialkeeper') PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trialkeeper')
args, unknown = PARSER.parse_known_args() args, unknown = PARSER.parse_known_args()
if args.trial_command is None: if args.trial_command is None:
......
...@@ -18,8 +18,12 @@ ...@@ -18,8 +18,12 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from .constants import API_ROOT_URL, BASE_URL, STDOUT_API, NNI_TRIAL_JOB_ID, NNI_EXP_ID from .constants import API_ROOT_URL, BASE_URL, STDOUT_API, NNI_TRIAL_JOB_ID, NNI_EXP_ID, VERSION_API
def gen_send_stdout_url(ip, port): def gen_send_stdout_url(ip, port):
'''Generate send stdout url''' '''Generate send stdout url'''
return '{0}:{1}{2}{3}/{4}/{5}'.format(BASE_URL.format(ip), port, API_ROOT_URL, STDOUT_API, NNI_EXP_ID, NNI_TRIAL_JOB_ID) return '{0}:{1}{2}{3}/{4}/{5}'.format(BASE_URL.format(ip), port, API_ROOT_URL, STDOUT_API, NNI_EXP_ID, NNI_TRIAL_JOB_ID)
\ No newline at end of file
def gen_send_version_url(ip, port):
    '''Generate the url used to send the version check result'''
    base_url = BASE_URL.format(ip)
    url_parts = (base_url, port, API_ROOT_URL, VERSION_API, NNI_EXP_ID, NNI_TRIAL_JOB_ID)
    return '{0}:{1}{2}{3}/{4}/{5}'.format(*url_parts)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment