Commit 8b8280b7 authored by Jonas Kaufmann, committed by Antoine Kaufmann
Browse files

implement Ctrl+C handling for LocalSimpleRuntime

parent d3f566f7
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse import argparse
import asyncio
import fnmatch import fnmatch
import importlib import importlib
import importlib.util import importlib.util
...@@ -29,6 +30,7 @@ import os ...@@ -29,6 +30,7 @@ import os
import pickle import pickle
import sys import sys
import typing as tp import typing as tp
from signal import SIGINT, signal
from simbricks.exectools import LocalExecutor, RemoteExecutor from simbricks.exectools import LocalExecutor, RemoteExecutor
from simbricks.experiment.experiment_environment import ExpEnv from simbricks.experiment.experiment_environment import ExpEnv
...@@ -355,4 +357,13 @@ else: ...@@ -355,4 +357,13 @@ else:
with open(path, 'rb') as f: with open(path, 'rb') as f:
rt.add_run(pickle.load(f)) rt.add_run(pickle.load(f))
rt.start()
# register interrupt handler
def handle_interrupt(signalnum: int, handler: int):
# pylint: disable=unused-argument
rt.interrupt()
signal(SIGINT, handle_interrupt)
asyncio.run(rt.start())
...@@ -47,7 +47,7 @@ class HostConfig(object): ...@@ -47,7 +47,7 @@ class HostConfig(object):
class Component(object): class Component(object):
def __init__(self, cmd_parts, with_stdin=False): def __init__(self, cmd_parts: tp.List[str], with_stdin=False):
self.is_ready = False self.is_ready = False
self.stdout = [] self.stdout = []
self.stdout_buf = bytearray() self.stdout_buf = bytearray()
...@@ -57,8 +57,8 @@ class Component(object): ...@@ -57,8 +57,8 @@ class Component(object):
#print(cmd_parts) #print(cmd_parts)
self.with_stdin = with_stdin self.with_stdin = with_stdin
self._proc: tp.Optional[Process] = None self._proc: Process
self._terminate_future: tp.Optional[asyncio.Task[int]] = None self._terminate_future: asyncio.Task
def _parse_buf(self, buf, data): def _parse_buf(self, buf, data):
if data is not None: if data is not None:
...@@ -76,21 +76,21 @@ class Component(object): ...@@ -76,21 +76,21 @@ class Component(object):
lines.append(buf.decode('utf-8')) lines.append(buf.decode('utf-8'))
return lines return lines
async def _consume_out(self, data): async def _consume_out(self, data: bytes):
eof = len(data) == 0 eof = len(data) == 0
ls = self._parse_buf(self.stdout_buf, data) ls = self._parse_buf(self.stdout_buf, data)
if len(ls) > 0 or eof: if len(ls) > 0 or eof:
await self.process_out(ls, eof=eof) await self.process_out(ls, eof=eof)
self.stdout = self.stdout + ls self.stdout = self.stdout + ls
async def _consume_err(self, data): async def _consume_err(self, data: bytes):
eof = len(data) == 0 eof = len(data) == 0
ls = self._parse_buf(self.stderr_buf, data) ls = self._parse_buf(self.stderr_buf, data)
if len(ls) > 0 or eof: if len(ls) > 0 or eof:
await self.process_err(ls, eof=eof) await self.process_err(ls, eof=eof)
self.stderr = self.stderr + ls self.stderr = self.stderr + ls
async def _read_stream(self, stream, fn): async def _read_stream(self, stream: asyncio.StreamReader, fn):
while True: while True:
bs = await stream.readline() bs = await stream.readline()
if bs: if bs:
...@@ -109,7 +109,6 @@ class Component(object): ...@@ -109,7 +109,6 @@ class Component(object):
rc = await self._proc.wait() rc = await self._proc.wait()
await out_handlers await out_handlers
await self.terminated(rc) await self.terminated(rc)
return rc
async def send_input(self, bs, eof=False): async def send_input(self, bs, eof=False):
self._proc.stdin.write(bs) self._proc.stdin.write(bs)
...@@ -128,36 +127,52 @@ class Component(object): ...@@ -128,36 +127,52 @@ class Component(object):
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
stdin=stdin, stdin=stdin,
) )
self._terminate_future = asyncio.ensure_future(self._waiter()) self._terminate_future = asyncio.create_task(self._waiter())
await self.started() await self.started()
async def wait(self): async def wait(self):
await self._terminate_future """
Wait for running process to finish and output to be collected.
On cancellation, the `CancelledError` is propagated but this component
keeps running.
"""
await asyncio.shield(self._terminate_future)
async def interrupt(self): async def interrupt(self):
if self._terminate_future.done(): """Sends an interrupt signal."""
return if self._proc.returncode is None:
self._proc.send_signal(signal.SIGINT) self._proc.send_signal(signal.SIGINT)
async def terminate(self): async def terminate(self):
if self._terminate_future.done(): """Sends a terminate signal."""
return if self._proc.returncode is None:
self._proc.terminate() self._proc.terminate()
async def kill(self): async def kill(self):
self._proc.kill() """Sends a kill signal."""
if self._proc.returncode is None:
self._proc.kill()
async def int_term_kill(self, delay=5): async def int_term_kill(self, delay=5):
"""Attempts to stop this component by sending signals in the following
order: interrupt, terminate, kill."""
await self.interrupt() await self.interrupt()
_, pending = await asyncio.wait([self._terminate_future], timeout=delay) _, pending = await asyncio.wait([self._proc.wait()], timeout=delay)
if len(pending) != 0: if len(pending) != 0:
print('terminating') print(
f'terminating component {self.cmd_parts[0]} '
f'pid {self._proc.pid}'
)
await self.terminate() await self.terminate()
_,pending = await asyncio.wait([self._terminate_future], _, pending = await asyncio.wait([self._proc.wait()], timeout=delay)
timeout=delay)
if len(pending) != 0: if len(pending) != 0:
print('killing') print(
f'killing component {self.cmd_parts[0]} '
f'pid {self._proc.pid}'
)
await self.kill() await self.kill()
await self._proc.wait()
async def started(self): async def started(self):
pass pass
...@@ -248,7 +263,7 @@ class SimpleRemoteComponent(SimpleComponent): ...@@ -248,7 +263,7 @@ class SimpleRemoteComponent(SimpleComponent):
] + self.extra_flags + [self.host_name, '--'] + parts ] + self.extra_flags + [self.host_name, '--'] + parts
async def start(self): async def start(self):
"""Start this command (includes waiting for its pid.""" """Start this command (includes waiting for its pid)."""
self._pid_fut = asyncio.get_running_loop().create_future() self._pid_fut = asyncio.get_running_loop().create_future()
await super().start() await super().start()
await self._pid_fut await self._pid_fut
......
...@@ -36,6 +36,7 @@ class ExpOutput(object): ...@@ -36,6 +36,7 @@ class ExpOutput(object):
self.end_time = None self.end_time = None
self.sims = {} self.sims = {}
self.success = True self.success = True
self.interrupted = False
def set_start(self): def set_start(self):
self.start_time = time.time() self.start_time = time.time()
...@@ -46,6 +47,10 @@ class ExpOutput(object): ...@@ -46,6 +47,10 @@ class ExpOutput(object):
def set_failed(self): def set_failed(self):
self.success = False self.success = False
def set_interrupted(self):
self.success = False
self.interrupted = True
def add_sim(self, sim, comp): def add_sim(self, sim, comp):
obj = { obj = {
'class': sim.__class__.__name__, 'class': sim.__class__.__name__,
......
...@@ -27,7 +27,7 @@ import traceback ...@@ -27,7 +27,7 @@ import traceback
import typing as tp import typing as tp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from simbricks.exectools import Executor, SimpleComponent from simbricks.exectools import Component, Executor, SimpleComponent
from simbricks.experiment.experiment_environment import ExpEnv from simbricks.experiment.experiment_environment import ExpEnv
from simbricks.experiment.experiment_output import ExpOutput from simbricks.experiment.experiment_output import ExpOutput
from simbricks.experiments import DistributedExperiment, Experiment from simbricks.experiments import DistributedExperiment, Experiment
...@@ -44,7 +44,7 @@ class ExperimentBaseRunner(ABC): ...@@ -44,7 +44,7 @@ class ExperimentBaseRunner(ABC):
self.out = ExpOutput(exp) self.out = ExpOutput(exp)
self.running: tp.List[tp.Tuple[Simulator, SimpleComponent]] = [] self.running: tp.List[tp.Tuple[Simulator, SimpleComponent]] = []
self.sockets = [] self.sockets = []
self.wait_sims = [] self.wait_sims: tp.List[Component] = []
@abstractmethod @abstractmethod
def sim_executor(self, sim: Simulator) -> Executor: def sim_executor(self, sim: Simulator) -> Executor:
...@@ -166,6 +166,10 @@ class ExperimentBaseRunner(ABC): ...@@ -166,6 +166,10 @@ class ExperimentBaseRunner(ABC):
await self.before_wait() await self.before_wait()
await self.wait_for_sims() await self.wait_for_sims()
except asyncio.CancelledError:
if self.verbose:
print(f'{self.exp.name}: interrupted')
self.out.set_interrupted()
except: # pylint: disable=bare-except except: # pylint: disable=bare-except
self.out.set_failed() self.out.set_failed()
traceback.print_exc() traceback.print_exc()
......
...@@ -76,10 +76,24 @@ class Run(object): ...@@ -76,10 +76,24 @@ class Run(object):
class Runtime(metaclass=ABCMeta): class Runtime(metaclass=ABCMeta):
"""Base class for managing the execution of multiple runs.""" """Base class for managing the execution of multiple runs."""
def __init__(self) -> None:
self._interrupted = False
"""Indicates whether interrupt has been signaled."""
@abstractmethod @abstractmethod
def add_run(self, run: Run): def add_run(self, run: Run):
pass pass
@abstractmethod @abstractmethod
def start(self): async def start(self):
pass pass
@abstractmethod
def interrupt(self):
"""
Signals an interrupt request.
As a consequence all currently running simulators should be stopped
cleanly and their output collected.
"""
self._interrupted = True
...@@ -34,6 +34,7 @@ from simbricks.runtime.common import Run, Runtime ...@@ -34,6 +34,7 @@ from simbricks.runtime.common import Run, Runtime
class DistributedSimpleRuntime(Runtime): class DistributedSimpleRuntime(Runtime):
def __init__(self, executors, verbose=False): def __init__(self, executors, verbose=False):
super().__init__()
self.runnable: tp.List[Run] = [] self.runnable: tp.List[Run] = []
self.complete: tp.List[Run] = [] self.complete: tp.List[Run] = []
self.verbose = verbose self.verbose = verbose
...@@ -63,10 +64,14 @@ class DistributedSimpleRuntime(Runtime): ...@@ -63,10 +64,14 @@ class DistributedSimpleRuntime(Runtime):
with open(run.outpath, 'w', encoding='utf-8') as f: with open(run.outpath, 'w', encoding='utf-8') as f:
f.write(run.output.dumps()) f.write(run.output.dumps())
def start(self): async def start(self):
for run in self.runnable: for run in self.runnable:
asyncio.run(self.do_run(run)) asyncio.run(self.do_run(run))
def interrupt(self):
return super().interrupt()
# TODO implement this
def auto_dist( def auto_dist(
e: Experiment, execs: tp.List[Executor], proxy_type: str = 'sockets' e: Experiment, execs: tp.List[Executor], proxy_type: str = 'sockets'
......
...@@ -37,32 +37,49 @@ class LocalSimpleRuntime(Runtime): ...@@ -37,32 +37,49 @@ class LocalSimpleRuntime(Runtime):
verbose=False, verbose=False,
executor: exectools.Executor = exectools.LocalExecutor() executor: exectools.Executor = exectools.LocalExecutor()
): ):
super().__init__()
self.runnable: tp.List[Run] = [] self.runnable: tp.List[Run] = []
self.complete: tp.List[Run] = [] self.complete: tp.List[Run] = []
self.verbose = verbose self.verbose = verbose
self.executor = executor self.executor = executor
self._running: tp.Optional[asyncio.Task] = None
def add_run(self, run: Run): def add_run(self, run: Run):
self.runnable.append(run) self.runnable.append(run)
async def do_run(self, run: Run): async def do_run(self, run: Run):
"""Actually executes `run`.""" """Actually executes `run`."""
runner = ExperimentSimpleRunner( try:
self.executor, run.experiment, run.env, self.verbose runner = ExperimentSimpleRunner(
) self.executor, run.experiment, run.env, self.verbose
await run.prep_dirs(self.executor) )
await runner.prepare() await run.prep_dirs(self.executor)
run.output = await runner.run() await runner.prepare()
except asyncio.CancelledError:
# it is safe to just exit here because we are not running any
# simulators yet
return
run.output = await runner.run() # already handles CancelledError
self.complete.append(run) self.complete.append(run)
pathlib.Path(run.outpath).parent.mkdir(parents=True, exist_ok=True) pathlib.Path(run.outpath).parent.mkdir(parents=True, exist_ok=True)
with open(run.outpath, 'w', encoding='utf-8') as f: with open(run.outpath, 'w', encoding='utf-8') as f:
f.write(run.output.dumps()) f.write(run.output.dumps())
def start(self): async def start(self):
"""Execute the runs defined in `self.runnable`.""" """Execute the runs defined in `self.runnable`."""
for run in self.runnable: for run in self.runnable:
asyncio.run(self.do_run(run)) if self._interrupted:
return
self._running = asyncio.create_task(self.do_run(run))
await self._running
def interrupt(self):
super().interrupt()
if self._running:
self._running.cancel()
class LocalParallelRuntime(Runtime): class LocalParallelRuntime(Runtime):
...@@ -75,6 +92,7 @@ class LocalParallelRuntime(Runtime): ...@@ -75,6 +92,7 @@ class LocalParallelRuntime(Runtime):
verbose=False, verbose=False,
executor: exectools.Executor = exectools.LocalExecutor() executor: exectools.Executor = exectools.LocalExecutor()
): ):
super().__init__()
self.runs_noprereq: tp.List[Run] = [] self.runs_noprereq: tp.List[Run] = []
"""Runs with no prerequesite runs.""" """Runs with no prerequesite runs."""
self.runs_prereq: tp.List[Run] = [] self.runs_prereq: tp.List[Run] = []
...@@ -180,6 +198,10 @@ class LocalParallelRuntime(Runtime): ...@@ -180,6 +198,10 @@ class LocalParallelRuntime(Runtime):
while self.pending_jobs: while self.pending_jobs:
await self.wait_completion() await self.wait_completion()
def start(self): async def start(self):
"""Execute all defined runs.""" """Execute all defined runs."""
asyncio.run(self.do_start()) asyncio.run(self.do_start())
def interrupt(self):
return super().interrupt()
# TODO implement this
...@@ -31,6 +31,7 @@ from simbricks.runtime.common import Run, Runtime ...@@ -31,6 +31,7 @@ from simbricks.runtime.common import Run, Runtime
class SlurmRuntime(Runtime): class SlurmRuntime(Runtime):
def __init__(self, slurmdir, args, verbose=False, cleanup=True): def __init__(self, slurmdir, args, verbose=False, cleanup=True):
super().__init__()
self.runnable = [] self.runnable = []
self.slurmdir = slurmdir self.slurmdir = slurmdir
self.args = args self.args = args
...@@ -87,7 +88,7 @@ class SlurmRuntime(Runtime): ...@@ -87,7 +88,7 @@ class SlurmRuntime(Runtime):
return exp_script return exp_script
def start(self): async def start(self):
pathlib.Path(self.slurmdir).mkdir(parents=True, exist_ok=True) pathlib.Path(self.slurmdir).mkdir(parents=True, exist_ok=True)
jid_re = re.compile(r'Submitted batch job ([0-9]+)') jid_re = re.compile(r'Submitted batch job ([0-9]+)')
...@@ -109,3 +110,7 @@ class SlurmRuntime(Runtime): ...@@ -109,3 +110,7 @@ class SlurmRuntime(Runtime):
m = jid_re.search(output) m = jid_re.search(output)
run.job_id = int(m.group(1)) run.job_id = int(m.group(1))
def interrupt(self):
return super().interrupt()
# TODO implement this
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment