Commit f9d8870f authored by Kai Chen's avatar Kai Chen
Browse files

auto-connect pavi service before running

parent 95cb8853
...@@ -6,10 +6,10 @@ from .optimizer_stepper import OptimizerHook ...@@ -6,10 +6,10 @@ from .optimizer_stepper import OptimizerHook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook, from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
pavi_hook_connect, TensorboardLoggerHook) TensorboardLoggerHook)
__all__ = [ __all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
'PaviLoggerHook', 'pavi_hook_connect', 'TensorboardLoggerHook' 'PaviLoggerHook', 'TensorboardLoggerHook'
] ]
from .base import LoggerHook from .base import LoggerHook
from .pavi import PaviLoggerHook, pavi_hook_connect from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook from .text import TextLoggerHook
__all__ = [ __all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'pavi_hook_connect', 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
'TensorboardLoggerHook'
] ]
from __future__ import print_function from __future__ import print_function
import logging
import os import os
import os.path as osp
import time import time
from datetime import datetime from datetime import datetime
from threading import Thread from threading import Thread
...@@ -20,6 +22,7 @@ class PaviClient(object): ...@@ -20,6 +22,7 @@ class PaviClient(object):
self.password = self._get_env_var(password, 'PAVI_PASSWORD') self.password = self._get_env_var(password, 'PAVI_PASSWORD')
self.instance_id = instance_id self.instance_id = instance_id
self.log_queue = None self.log_queue = None
self.logger = None
def _get_env_var(self, var, env_var): def _get_env_var(self, var, env_var):
if var is not None: if var is not None:
...@@ -32,25 +35,28 @@ class PaviClient(object): ...@@ -32,25 +35,28 @@ class PaviClient(object):
format(env_var)) format(env_var))
return var return var
def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
    """Emit ``msg`` through the attached logger, or print it as a fallback.

    Args:
        msg: Message to emit.
        level: Logging level handed to ``self.logger.log``; ignored when
            no logger is attached.
        *args: Extra positional arguments forwarded unchanged to
            ``self.logger.log`` or to ``print``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.logger.log`` or to ``print``.
    """
    sink = self.logger
    if sink is None:
        # No logger attached: fall back to plain printing.
        print(msg, *args, **kwargs)
        return
    sink.log(level, msg, *args, **kwargs)
def connect(self, def connect(self,
model_name, model_name,
work_dir=None, work_dir=None,
info=dict(), info=dict(),
timeout=5, timeout=5,
logger=None): logger=None):
if logger: if logger is not None:
log_info = logger.info self.logger = logger
log_error = logger.error self._print_log('connecting pavi service {}...'.format(self.url))
else:
log_info = log_error = print
log_info('connecting pavi service {}...'.format(self.url))
post_data = dict( post_data = dict(
time=str(datetime.now()), time=str(datetime.now()),
username=self.username, username=self.username,
password=self.password, password=self.password,
instance_id=self.instance_id, instance_id=self.instance_id,
model=model_name, model=model_name,
work_dir=os.path.abspath(work_dir) if work_dir else '', work_dir=osp.abspath(work_dir) if work_dir else '',
session_file=info.get('session_file', ''), session_file=info.get('session_file', ''),
session_text=info.get('session_text', ''), session_text=info.get('session_text', ''),
model_text=info.get('model_text', ''), model_text=info.get('model_text', ''),
...@@ -58,11 +64,14 @@ class PaviClient(object): ...@@ -58,11 +64,14 @@ class PaviClient(object):
try: try:
response = requests.post(self.url, json=post_data, timeout=timeout) response = requests.post(self.url, json=post_data, timeout=timeout)
except Exception as ex: except Exception as ex:
log_error('fail to connect to pavi service: {}'.format(ex)) self._print_log(
'fail to connect to pavi service: {}'.format(ex),
level=logging.ERROR)
else: else:
if response.status_code == 200: if response.status_code == 200:
self.instance_id = response.text self.instance_id = response.text
log_info('pavi service connected, instance_id: {}'.format( self._print_log(
'pavi service connected, instance_id: {}'.format(
self.instance_id)) self.instance_id))
self.log_queue = Queue() self.log_queue = Queue()
self.log_thread = Thread(target=self.post_worker_fn) self.log_thread = Thread(target=self.post_worker_fn)
...@@ -70,9 +79,11 @@ class PaviClient(object): ...@@ -70,9 +79,11 @@ class PaviClient(object):
self.log_thread.start() self.log_thread.start()
return True return True
else: else:
log_error('fail to connect to pavi service, status code: ' self._print_log(
'fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code, '{}, err message: {}'.format(response.status_code,
response.reason)) response.reason),
level=logging.ERROR)
return False return False
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3): def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
...@@ -82,7 +93,9 @@ class PaviClient(object): ...@@ -82,7 +93,9 @@ class PaviClient(object):
except Empty: except Empty:
time.sleep(1) time.sleep(1)
except Exception as ex: except Exception as ex:
print('fail to get logs from queue: {}'.format(ex)) self._print_log(
'fail to get logs from queue: {}'.format(ex),
level=logging.ERROR)
else: else:
retry = 0 retry = 0
while retry < max_retry: while retry < max_retry:
...@@ -91,17 +104,24 @@ class PaviClient(object): ...@@ -91,17 +104,24 @@ class PaviClient(object):
self.url, json=log, timeout=req_timeout) self.url, json=log, timeout=req_timeout)
except Exception as ex: except Exception as ex:
retry += 1 retry += 1
print('error when posting logs to pavi: {}'.format(ex)) self._print_log(
'error when posting logs to pavi: {}'.format(ex),
level=logging.ERROR)
else: else:
status_code = response.status_code status_code = response.status_code
if status_code == 200: if status_code == 200:
break break
else: else:
print('unexpected status code: %d, err msg: %s', self._print_log(
status_code, response.reason) 'unexpected status code: %d, err msg: {}'.
format(status_code, response.reason),
level=logging.ERROR)
retry += 1 retry += 1
if retry == max_retry: if retry == max_retry:
print('fail to send logs of iteration %d', log['iter_num']) self._print_log(
'fail to send logs of iteration {}'.format(
log['iter_num']),
level=logging.ERROR)
def log(self, phase, iter, outputs): def log(self, phase, iter, outputs):
if self.log_queue is not None: if self.log_queue is not None:
...@@ -123,21 +143,29 @@ class PaviLoggerHook(LoggerHook): ...@@ -123,21 +143,29 @@ class PaviLoggerHook(LoggerHook):
username=None, username=None,
password=None, password=None,
instance_id=None, instance_id=None,
config_file=None,
interval=10, interval=10,
reset_meter=True, reset_meter=True,
ignore_last=True): ignore_last=True):
self.pavi = PaviClient(url, username, password, instance_id) self.pavi = PaviClient(url, username, password, instance_id)
self.config_file = config_file
super(PaviLoggerHook, self).__init__(interval, reset_meter, super(PaviLoggerHook, self).__init__(interval, reset_meter,
ignore_last) ignore_last)
def before_run(self, runner):
    """Run the parent ``before_run`` step, then call ``self.connect``.

    Args:
        runner: Runner object passed on to both the parent hook and
            ``self.connect``.
    """
    parent = super(PaviLoggerHook, self)
    parent.before_run(runner)
    # Connect only after the parent hook has finished its own setup; the
    # value returned by ``connect`` is not used here.
    self.connect(runner)
@master_only @master_only
def connect(self, def connect(self, runner, timeout=5):
model_name, cfg_info = dict()
work_dir=None, if self.config_file is not None:
info=dict(), with open(self.config_file, 'r') as f:
timeout=5, config_text = f.read()
logger=None): cfg_info.update(
return self.pavi.connect(model_name, work_dir, info, timeout, logger) session_file=self.config_file, session_text=config_text)
return self.pavi.connect(runner.model_name, runner.work_dir, cfg_info,
timeout, runner.logger)
@master_only @master_only
def log(self, runner): def log(self, runner):
...@@ -145,17 +173,3 @@ class PaviLoggerHook(LoggerHook): ...@@ -145,17 +173,3 @@ class PaviLoggerHook(LoggerHook):
log_outs.pop('time', None) log_outs.pop('time', None)
log_outs.pop('data_time', None) log_outs.pop('data_time', None)
self.pavi.log(runner.mode, runner.iter, log_outs) self.pavi.log(runner.mode, runner.iter, log_outs)
def pavi_hook_connect(runner, cfg_filename, cfg_text):
    """Connect the first ``PaviLoggerHook`` found in ``runner.hooks``.

    Only the first matching hook is used. Nothing happens when
    ``runner.hooks`` holds no ``PaviLoggerHook``.

    Args:
        runner: Object providing ``hooks``, ``model_name``, ``work_dir``
            and ``logger``.
        cfg_filename: Value sent as ``session_file`` in the info dict.
        cfg_text: Value sent as ``session_text`` in the info dict.
    """
    # Lazy scan: stops at the first PaviLoggerHook, like a loop with break.
    candidates = (h for h in runner.hooks if isinstance(h, PaviLoggerHook))
    pavi_hook = next(candidates, None)
    if pavi_hook is None:
        return
    pavi_hook.connect(
        runner.model_name,
        runner.work_dir,
        info={
            'session_file': cfg_filename,
            'session_text': cfg_text
        },
        logger=runner.logger)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment