Commit 2d6ae1a5 authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

orchestration: fix multiple interrupt signals producing exception

parent 1ff4e4c9
......@@ -91,11 +91,20 @@ class Runtime(metaclass=ABCMeta):
pass
@abstractmethod
def interrupt(self):
def interrupt_handler(self) -> None:
"""
Signals an interrupt request.
Interrupts signal handler.
As a consequence all currently running simulators should be stopped
cleanly and their output collected.
All currently running simulators should be stopped cleanly and their
output collected.
"""
pass
def interrupt(self) -> None:
"""signals interrupt to runtime."""
# don't invoke interrupt handler multiple times as this would trigger
# repeated CancelledError
if not self._interrupted:
self._interrupted = True
self.interrupt_handler()
......@@ -77,7 +77,7 @@ class DistributedSimpleRuntime(Runtime):
)
run.output.dump(run.outpath)
async def start(self):
async def start(self) -> None:
for run in self.runnable:
if self._interrupted:
return
......@@ -85,8 +85,7 @@ class DistributedSimpleRuntime(Runtime):
self._running = asyncio.create_task(self.do_run(run))
await self._running
def interrupt(self):
super().interrupt()
def interrupt_handler(self) -> None:
if self._running:
self._running.cancel()
......
......@@ -59,7 +59,7 @@ class LocalSimpleRuntime(Runtime):
# simulators yet
return
run.output = await runner.run() # already handles CancelledError
run.output = await runner.run() # handles CancelledError
self.complete.append(run)
# if the log is huge, this step takes some time
......@@ -78,8 +78,7 @@ class LocalSimpleRuntime(Runtime):
self._running = asyncio.create_task(self.do_run(run))
await self._running
def interrupt(self):
super().interrupt()
def interrupt_handler(self):
if self._running:
self._running.cancel()
......@@ -210,7 +209,7 @@ class LocalParallelRuntime(Runtime):
# wait for all runs to finish
await asyncio.wait(self._pending_jobs)
async def start(self):
async def start(self) -> None:
"""Execute all defined runs."""
self._starter_task = asyncio.create_task(self.do_start())
try:
......@@ -221,6 +220,5 @@ class LocalParallelRuntime(Runtime):
# wait for all runs to finish
await asyncio.wait(self._pending_jobs)
def interrupt(self):
super().interrupt()
def interrupt_handler(self) -> None:
self._starter_task.cancel()
......@@ -117,7 +117,7 @@ class SlurmRuntime(Runtime):
raise RuntimeError('cannot retrieve id of submitted job')
run.job_id = int(m.group(1))
async def start(self):
async def start(self) -> None:
self._start_task = asyncio.create_task(self._do_start())
try:
await self._start_task
......@@ -134,6 +134,5 @@ class SlurmRuntime(Runtime):
)
await scancel_process.wait()
def interrupt(self):
super().interrupt()
def interrupt_handler(self) -> None:
self._start_task.cancel()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment