Unverified Commit 8afaa376 authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Enhance timeout cleanup to avoid possible hanging (#405)

Enhance timeout cleanup to avoid possible hanging.

__Major Revisions__
* Skip postprocess (mainly torch.dist.barrier and destroy) when exception happens (e.g., timeout, GPU crashed) to avoid subprocesses hanging.
* Add cleanup to kill sb exec processes when Ansible run failed for certain benchmark.

__Minor Revisions__
* Update extra Ansible timeout from 300s to 60s.
parent db842892
......@@ -171,10 +171,11 @@ def run(self):
except BaseException as e:
self._result.set_return_code(ReturnCode.RUNTIME_EXCEPTION_ERROR)
logger.error('Run benchmark failed - benchmark: {}, message: {}'.format(self._name, str(e)))
else:
ret &= self._postprocess()
finally:
self._end_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
self._result.set_timestamp(self._start_time, self._end_time)
ret &= self._postprocess()
return ret
......
- name: Runtime Environment Cleanup
hosts: all
gather_facts: false
tasks:
- name: Killing sb exec processes
shell: |
pgrep -ax sb | grep 'sb exec' | awk '{print $1}' | xargs kill -9 ||:
become: yes
......@@ -193,6 +193,10 @@ def check_env(self): # pragma: no cover
)
)
def cleanup(self): # pragma: no cover
"""Cleanup remaining processes on all nodes."""
self._ansible_client.run(self._ansible_client.get_playbook_config('cleanup.yaml'))
def fetch_results(self): # pragma: no cover
"""Fetch benchmark results on all nodes."""
try:
......@@ -410,7 +414,7 @@ def _run_proc(self, benchmark_name, mode, vars):
if isinstance(timeout, int):
# we do not expect timeout in ansible unless subprocess hangs
ansible_runner_config['timeout'] = timeout + 300
ansible_runner_config['timeout'] = timeout + 60
rc = self._ansible_client.run(ansible_runner_config, sudo=(not self._docker_config.skip))
return rc
......@@ -423,16 +427,20 @@ def run(self):
continue
benchmark_config = self._sb_benchmarks[benchmark_name]
for mode in benchmark_config.modes:
ansible_rc = 0
if mode.name == 'local':
Parallel(n_jobs=mode.proc_num if mode.parallel else 1)(
rc_list = Parallel(n_jobs=mode.proc_num if mode.parallel else 1)(
delayed(self._run_proc)(benchmark_name, mode, {
'proc_rank': proc_rank
}) for proc_rank in range(mode.proc_num)
)
ansible_rc = sum(rc_list)
elif mode.name == 'torch.distributed' or mode.name == 'mpi':
self._run_proc(benchmark_name, mode, {'proc_rank': 0})
ansible_rc = self._run_proc(benchmark_name, mode, {'proc_rank': 0})
else:
logger.warning('Unknown mode %s.', mode.name)
if ansible_rc != 0:
self.cleanup()
self.fetch_results()
self.__create_results_summary()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment