Unverified commit 9b8df883, authored by Yifan Xiong, committed by GitHub

Gracefully exit when timeout (#383)

* Gracefully exit on timeout; add a corresponding log message and return code.
* Set the minimum timeout to 1 minute and enlarge the Ansible timeout.
parent ec16d425
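
Taken together, the changes work as follows: the runner enforces a per-benchmark timeout and sends SIGTERM when it expires, and the benchmark process converts that signal into a TimeoutError so it can log the event, record a dedicated return code, and still run its postprocessing. A minimal sketch of the runner's side of this contract, assuming a placeholder benchmark.py and plain subprocess in place of the Ansible path SuperBench actually uses:

    import subprocess

    # Illustrative only: the real runner dispatches through Ansible, and
    # 'benchmark.py' is a hypothetical placeholder for the benchmark process.
    proc = subprocess.Popen(['python3', 'benchmark.py'])
    try:
        proc.wait(timeout=60)    # per-benchmark timeout, at least 1 minute
    except subprocess.TimeoutExpired:
        proc.terminate()         # SIGTERM: the child's handler raises TimeoutError
        proc.wait()              # child logs, sets its return code, cleans up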
superbench/benchmarks/base.py
@@ -3,6 +3,8 @@
 """Module of the base class."""
+import signal
+import traceback
 import argparse
 import numbers
 from datetime import datetime
@@ -153,25 +155,40 @@ def run(self):
             True if run benchmark successfully.
         """
         ret = True
+        self._start_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
+        try:
             ret &= self._preprocess()
             if ret:
-                self._start_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
+                signal.signal(signal.SIGTERM, self.__signal_handler)
                 for self._curr_run_index in range(self._args.run_count):
                     ret &= self._benchmark()
-                self._end_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
-                self._result.set_timestamp(self._start_time, self._end_time)
                 if ret:
                     ret &= self.__check_result_format()
+        except TimeoutError as e:
+            self._result.set_return_code(ReturnCode.KILLED_BY_TIMEOUT)
+            logger.error('Run benchmark failed - benchmark: %s, message: %s', self._name, e)
+        except BaseException as e:
+            self._result.set_return_code(ReturnCode.RUNTIME_EXCEPTION_ERROR)
+            logger.error('Run benchmark failed - benchmark: {}, message: {}'.format(self._name, str(e)))
+        finally:
+            self._end_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
+            self._result.set_timestamp(self._start_time, self._end_time)
             ret &= self._postprocess()
         return ret

+    def __signal_handler(self, signum, frame):
+        """Signal handler for benchmark.
+
+        Args:
+            signum (int): Signal number.
+            frame (FrameType): Current stack frame.
+        """
+        logger.debug('Killed by %s', signal.Signals(signum).name)
+        logger.debug(''.join(traceback.format_stack(frame, 5)))
+        if signum == signal.SIGTERM:
+            raise TimeoutError('Killed by SIGTERM or timeout!')
+
     def __check_result_format(self):
         """Check the validation of result object.
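
The heart of this hunk is that the handler turns an asynchronous SIGTERM into an ordinary exception raised at the interrupted point of execution, which the new try/except/finally in run() routes to KILLED_BY_TIMEOUT while timestamps and postprocessing still happen. A standalone sketch of the same pattern (POSIX-only; the names here are illustrative, not SuperBench APIs):

    import os
    import signal

    def handler(signum, frame):
        # Turn the asynchronous signal into a synchronous, catchable exception.
        raise TimeoutError('Killed by SIGTERM or timeout!')

    signal.signal(signal.SIGTERM, handler)
    try:
        os.kill(os.getpid(), signal.SIGTERM)    # simulate the runner's kill
    except TimeoutError as e:
        print('caught:', e)                     # set the return code here
    finally:
        print('cleanup still runs')             # timestamps, postprocessing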
superbench/benchmarks/return_code.py
@@ -14,6 +14,7 @@ class ReturnCode(Enum):
     INVALID_BENCHMARK_TYPE = 2
     INVALID_BENCHMARK_RESULT = 3
     RUNTIME_EXCEPTION_ERROR = 4
+    KILLED_BY_TIMEOUT = 124
     # Return codes related to model benchmarks.
     NO_SUPPORTED_PRECISION = 10
     DISTRIBUTED_SETTING_INIT_FAILURE = 13
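
The value 124 appears to follow the convention of the coreutils timeout(1) utility, which exits with status 124 when a command exceeds its time limit, so a timed-out benchmark carries the standard code:

    from superbench.benchmarks import ReturnCode

    # 124 mirrors timeout(1)'s exit status for an expired command.
    assert ReturnCode.KILLED_BY_TIMEOUT.value == 124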
superbench/runner/runner.py
@@ -388,6 +388,9 @@ def _run_proc(self, benchmark_name, mode, vars):
         logger.info('Runner is going to run %s in %s mode, proc rank %d.', benchmark_name, mode.name, mode.proc_rank)

         timeout = self._sb_benchmarks[benchmark_name].timeout
+        if isinstance(timeout, int):
+            timeout = max(timeout, 60)
+
         env_list = '--env-file /tmp/sb.env'
         if self._docker_config.skip:
             env_list = 'set -o allexport && source /tmp/sb.env && set +o allexport'
@@ -405,7 +408,9 @@ def _run_proc(self, benchmark_name, mode, vars):
         if mode.name == 'mpi' and mode.node_num != 1:
             ansible_runner_config = self._ansible_client.update_mpi_config(ansible_runner_config)
-        ansible_runner_config['timeout'] = timeout
+        if isinstance(timeout, int):
+            # we do not expect timeout in ansible unless subprocess hangs
+            ansible_runner_config['timeout'] = timeout + 300
         rc = self._ansible_client.run(ansible_runner_config, sudo=(not self._docker_config.skip))
         return rc
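
The two hunks above implement a single policy: clamp every integer timeout to at least 60 seconds, then give Ansible a 300-second cushion on top so its own timeout only fires if the remote subprocess hangs past its kill. A compact restatement of that arithmetic (a hypothetical helper, not actual runner code):

    def adjusted_timeouts(timeout):
        # Hypothetical helper restating the runner's timeout arithmetic.
        if isinstance(timeout, int):
            timeout = max(timeout, 60)    # minimum timeout is 1 minute
            # Ansible should only time out if the subprocess hangs past
            # its own kill, hence 5 extra minutes of slack.
            return timeout, timeout + 300
        return timeout, None              # no timeout configured

    assert adjusted_timeouts(10) == (60, 360)
    assert adjusted_timeouts(600) == (600, 900)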
tests/benchmarks/test_base.py (new file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench benchmark base test."""

import os
import time
import signal
import unittest
from multiprocessing import Process, Queue

from superbench.benchmarks import BenchmarkType, ReturnCode
from superbench.benchmarks.base import Benchmark


class FooBenchmark(Benchmark):
    """Foobar benchmark for test.

    Args:
        Benchmark (Benchmark): Base Benchmark class.
    """
    def _benchmark(self):
        """Implement _benchmark method.

        Returns:
            bool: True if run benchmark successfully.
        """
        time.sleep(2)
        return True

    def test_run(self, pid_queue, rc_queue):
        """Method to test benchmark run.

        Args:
            pid_queue (Queue): Multiprocessing queue to share pid.
            rc_queue (Queue): Multiprocessing queue to share return code.
        """
        pid_queue.put(os.getpid())
        self.run()
        rc_queue.put(self.return_code)


class BenchmarkBaseTestCase(unittest.TestCase):
    """A class for benchmark base test cases."""
    def setUp(self):
        """Hook method for setting up the test fixture before exercising it."""
        self.benchmark = FooBenchmark('foo')
        self.benchmark._benchmark_type = BenchmarkType.MICRO
        self.pid_queue = Queue()
        self.rc_queue = Queue()

    def test_signal_handler(self):
        """Test signal handler when running benchmarks."""
        test_cases = [
            {
                'signal': None,
                'return_code': ReturnCode.SUCCESS,
            },
            {
                'signal': signal.SIGTERM,
                'return_code': ReturnCode.KILLED_BY_TIMEOUT,
            },
        ]
        for test_case in test_cases:
            with self.subTest(msg='Testing with case', test_case=test_case):
                proc = Process(target=self.benchmark.test_run, args=(
                    self.pid_queue,
                    self.rc_queue,
                ))
                proc.start()
                proc_pid = self.pid_queue.get(block=True, timeout=3)
                if test_case['signal']:
                    killer = Process(target=os.kill, args=(proc_pid, test_case['signal']))
                    killer.start()
                    killer.join()
                proc.join()
                self.assertEqual(self.rc_queue.get(block=True, timeout=3), test_case['return_code'])
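
Note the test's structure: run() executes in a separate multiprocessing.Process because the SIGTERM must be delivered to the process that installed the handler, not to the test runner itself. The pid_queue tells the killer process whom to signal, and the rc_queue carries the resulting return code back across the process boundary for the assertion.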