Commit 8b8280b7 authored by Jonas Kaufmann, committed by Antoine Kaufmann
Browse files

implement Ctrl+C handling for LocalSimpleRuntime

parent d3f566f7
......@@ -21,6 +21,7 @@
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
import asyncio
import fnmatch
import importlib
import importlib.util
......@@ -29,6 +30,7 @@ import os
import pickle
import sys
import typing as tp
from signal import SIGINT, signal
from simbricks.exectools import LocalExecutor, RemoteExecutor
from simbricks.experiment.experiment_environment import ExpEnv
......@@ -355,4 +357,13 @@ else:
with open(path, 'rb') as f:
rt.add_run(pickle.load(f))
rt.start()
# register interrupt handler
def handle_interrupt(signalnum: int, handler: int):
# pylint: disable=unused-argument
rt.interrupt()
signal(SIGINT, handle_interrupt)
asyncio.run(rt.start())
......@@ -47,7 +47,7 @@ class HostConfig(object):
class Component(object):
def __init__(self, cmd_parts, with_stdin=False):
def __init__(self, cmd_parts: tp.List[str], with_stdin=False):
self.is_ready = False
self.stdout = []
self.stdout_buf = bytearray()
......@@ -57,8 +57,8 @@ class Component(object):
#print(cmd_parts)
self.with_stdin = with_stdin
self._proc: tp.Optional[Process] = None
self._terminate_future: tp.Optional[asyncio.Task[int]] = None
self._proc: Process
self._terminate_future: asyncio.Task
def _parse_buf(self, buf, data):
if data is not None:
......@@ -76,21 +76,21 @@ class Component(object):
lines.append(buf.decode('utf-8'))
return lines
async def _consume_out(self, data):
async def _consume_out(self, data: bytes):
eof = len(data) == 0
ls = self._parse_buf(self.stdout_buf, data)
if len(ls) > 0 or eof:
await self.process_out(ls, eof=eof)
self.stdout = self.stdout + ls
async def _consume_err(self, data):
async def _consume_err(self, data: bytes):
eof = len(data) == 0
ls = self._parse_buf(self.stderr_buf, data)
if len(ls) > 0 or eof:
await self.process_err(ls, eof=eof)
self.stderr = self.stderr + ls
async def _read_stream(self, stream, fn):
async def _read_stream(self, stream: asyncio.StreamReader, fn):
while True:
bs = await stream.readline()
if bs:
......@@ -109,7 +109,6 @@ class Component(object):
rc = await self._proc.wait()
await out_handlers
await self.terminated(rc)
return rc
async def send_input(self, bs, eof=False):
self._proc.stdin.write(bs)
......@@ -128,36 +127,52 @@ class Component(object):
stderr=asyncio.subprocess.PIPE,
stdin=stdin,
)
self._terminate_future = asyncio.ensure_future(self._waiter())
self._terminate_future = asyncio.create_task(self._waiter())
await self.started()
async def wait(self):
await self._terminate_future
"""
Wait for running process to finish and output to be collected.
On cancellation, the `CancelledError` is propagated but this component
keeps running.
"""
await asyncio.shield(self._terminate_future)
async def interrupt(self):
if self._terminate_future.done():
return
"""Sends an interrupt signal."""
if self._proc.returncode is None:
self._proc.send_signal(signal.SIGINT)
async def terminate(self):
if self._terminate_future.done():
return
"""Sends a terminate signal."""
if self._proc.returncode is None:
self._proc.terminate()
async def kill(self):
"""Sends a kill signal."""
if self._proc.returncode is None:
self._proc.kill()
async def int_term_kill(self, delay=5):
"""Attempts to stop this component by sending signals in the following
order: interrupt, terminate, kill."""
await self.interrupt()
_, pending = await asyncio.wait([self._terminate_future], timeout=delay)
_, pending = await asyncio.wait([self._proc.wait()], timeout=delay)
if len(pending) != 0:
print('terminating')
print(
f'terminating component {self.cmd_parts[0]} '
f'pid {self._proc.pid}'
)
await self.terminate()
_,pending = await asyncio.wait([self._terminate_future],
timeout=delay)
_, pending = await asyncio.wait([self._proc.wait()], timeout=delay)
if len(pending) != 0:
print('killing')
print(
f'killing component {self.cmd_parts[0]} '
f'pid {self._proc.pid}'
)
await self.kill()
await self._proc.wait()
async def started(self):
pass
......@@ -248,7 +263,7 @@ class SimpleRemoteComponent(SimpleComponent):
] + self.extra_flags + [self.host_name, '--'] + parts
async def start(self):
"""Start this command (includes waiting for its pid."""
"""Start this command (includes waiting for its pid)."""
self._pid_fut = asyncio.get_running_loop().create_future()
await super().start()
await self._pid_fut
......
......@@ -36,6 +36,7 @@ class ExpOutput(object):
self.end_time = None
self.sims = {}
self.success = True
self.interrupted = False
def set_start(self):
self.start_time = time.time()
......@@ -46,6 +47,10 @@ class ExpOutput(object):
def set_failed(self):
self.success = False
def set_interrupted(self):
self.success = False
self.interrupted = True
def add_sim(self, sim, comp):
obj = {
'class': sim.__class__.__name__,
......
......@@ -27,7 +27,7 @@ import traceback
import typing as tp
from abc import ABC, abstractmethod
from simbricks.exectools import Executor, SimpleComponent
from simbricks.exectools import Component, Executor, SimpleComponent
from simbricks.experiment.experiment_environment import ExpEnv
from simbricks.experiment.experiment_output import ExpOutput
from simbricks.experiments import DistributedExperiment, Experiment
......@@ -44,7 +44,7 @@ class ExperimentBaseRunner(ABC):
self.out = ExpOutput(exp)
self.running: tp.List[tp.Tuple[Simulator, SimpleComponent]] = []
self.sockets = []
self.wait_sims = []
self.wait_sims: tp.List[Component] = []
@abstractmethod
def sim_executor(self, sim: Simulator) -> Executor:
......@@ -166,6 +166,10 @@ class ExperimentBaseRunner(ABC):
await self.before_wait()
await self.wait_for_sims()
except asyncio.CancelledError:
if self.verbose:
print(f'{self.exp.name}: interrupted')
self.out.set_interrupted()
except: # pylint: disable=bare-except
self.out.set_failed()
traceback.print_exc()
......
......@@ -76,10 +76,24 @@ class Run(object):
class Runtime(metaclass=ABCMeta):
"""Base class for managing the execution of multiple runs."""
def __init__(self) -> None:
self._interrupted = False
"""Indicates whether interrupt has been signaled."""
@abstractmethod
def add_run(self, run: Run):
pass
@abstractmethod
def start(self):
async def start(self):
pass
@abstractmethod
def interrupt(self):
"""
Signals an interrupt request.
As a consequence all currently running simulators should be stopped
cleanly and their output collected.
"""
self._interrupted = True
......@@ -34,6 +34,7 @@ from simbricks.runtime.common import Run, Runtime
class DistributedSimpleRuntime(Runtime):
def __init__(self, executors, verbose=False):
super().__init__()
self.runnable: tp.List[Run] = []
self.complete: tp.List[Run] = []
self.verbose = verbose
......@@ -63,10 +64,14 @@ class DistributedSimpleRuntime(Runtime):
with open(run.outpath, 'w', encoding='utf-8') as f:
f.write(run.output.dumps())
def start(self):
async def start(self):
for run in self.runnable:
asyncio.run(self.do_run(run))
def interrupt(self):
return super().interrupt()
# TODO implement this
def auto_dist(
e: Experiment, execs: tp.List[Executor], proxy_type: str = 'sockets'
......
......@@ -37,32 +37,49 @@ class LocalSimpleRuntime(Runtime):
verbose=False,
executor: exectools.Executor = exectools.LocalExecutor()
):
super().__init__()
self.runnable: tp.List[Run] = []
self.complete: tp.List[Run] = []
self.verbose = verbose
self.executor = executor
self._running: tp.Optional[asyncio.Task] = None
def add_run(self, run: Run):
self.runnable.append(run)
async def do_run(self, run: Run):
"""Actually executes `run`."""
try:
runner = ExperimentSimpleRunner(
self.executor, run.experiment, run.env, self.verbose
)
await run.prep_dirs(self.executor)
await runner.prepare()
run.output = await runner.run()
except asyncio.CancelledError:
# it is safe to just exit here because we are not running any
# simulators yet
return
run.output = await runner.run() # already handles CancelledError
self.complete.append(run)
pathlib.Path(run.outpath).parent.mkdir(parents=True, exist_ok=True)
with open(run.outpath, 'w', encoding='utf-8') as f:
f.write(run.output.dumps())
def start(self):
async def start(self):
"""Execute the runs defined in `self.runnable`."""
for run in self.runnable:
asyncio.run(self.do_run(run))
if self._interrupted:
return
self._running = asyncio.create_task(self.do_run(run))
await self._running
def interrupt(self):
super().interrupt()
if self._running:
self._running.cancel()
class LocalParallelRuntime(Runtime):
......@@ -75,6 +92,7 @@ class LocalParallelRuntime(Runtime):
verbose=False,
executor: exectools.Executor = exectools.LocalExecutor()
):
super().__init__()
self.runs_noprereq: tp.List[Run] = []
"""Runs with no prerequesite runs."""
self.runs_prereq: tp.List[Run] = []
......@@ -180,6 +198,10 @@ class LocalParallelRuntime(Runtime):
while self.pending_jobs:
await self.wait_completion()
def start(self):
async def start(self):
"""Execute all defined runs."""
asyncio.run(self.do_start())
def interrupt(self):
return super().interrupt()
# TODO implement this
......@@ -31,6 +31,7 @@ from simbricks.runtime.common import Run, Runtime
class SlurmRuntime(Runtime):
def __init__(self, slurmdir, args, verbose=False, cleanup=True):
super().__init__()
self.runnable = []
self.slurmdir = slurmdir
self.args = args
......@@ -87,7 +88,7 @@ class SlurmRuntime(Runtime):
return exp_script
def start(self):
async def start(self):
pathlib.Path(self.slurmdir).mkdir(parents=True, exist_ok=True)
jid_re = re.compile(r'Submitted batch job ([0-9]+)')
......@@ -109,3 +110,7 @@ class SlurmRuntime(Runtime):
m = jid_re.search(output)
run.job_id = int(m.group(1))
def interrupt(self):
return super().interrupt()
# TODO implement this
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment