Commit 2d6ae1a5 authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

orchestration: fix multiple interrupt signals producing exception

parent 1ff4e4c9
...@@ -91,11 +91,20 @@ class Runtime(metaclass=ABCMeta): ...@@ -91,11 +91,20 @@ class Runtime(metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def interrupt(self): def interrupt_handler(self) -> None:
""" """
Signals an interrupt request. Interrupts signal handler.
As a consequence all currently running simulators should be stopped All currently running simulators should be stopped cleanly and their
cleanly and their output collected. output collected.
""" """
self._interrupted = True pass
def interrupt(self) -> None:
"""signals interrupt to runtime."""
# don't invoke interrupt handler multiple times as this would trigger
# repeated CancelledError
if not self._interrupted:
self._interrupted = True
self.interrupt_handler()
...@@ -77,7 +77,7 @@ class DistributedSimpleRuntime(Runtime): ...@@ -77,7 +77,7 @@ class DistributedSimpleRuntime(Runtime):
) )
run.output.dump(run.outpath) run.output.dump(run.outpath)
async def start(self): async def start(self) -> None:
for run in self.runnable: for run in self.runnable:
if self._interrupted: if self._interrupted:
return return
...@@ -85,8 +85,7 @@ class DistributedSimpleRuntime(Runtime): ...@@ -85,8 +85,7 @@ class DistributedSimpleRuntime(Runtime):
self._running = asyncio.create_task(self.do_run(run)) self._running = asyncio.create_task(self.do_run(run))
await self._running await self._running
def interrupt(self): def interrupt_handler(self) -> None:
super().interrupt()
if self._running: if self._running:
self._running.cancel() self._running.cancel()
......
...@@ -59,7 +59,7 @@ class LocalSimpleRuntime(Runtime): ...@@ -59,7 +59,7 @@ class LocalSimpleRuntime(Runtime):
# simulators yet # simulators yet
return return
run.output = await runner.run() # already handles CancelledError run.output = await runner.run() # handles CancelledError
self.complete.append(run) self.complete.append(run)
# if the log is huge, this step takes some time # if the log is huge, this step takes some time
...@@ -78,8 +78,7 @@ class LocalSimpleRuntime(Runtime): ...@@ -78,8 +78,7 @@ class LocalSimpleRuntime(Runtime):
self._running = asyncio.create_task(self.do_run(run)) self._running = asyncio.create_task(self.do_run(run))
await self._running await self._running
def interrupt(self): def interrupt_handler(self):
super().interrupt()
if self._running: if self._running:
self._running.cancel() self._running.cancel()
...@@ -210,7 +209,7 @@ class LocalParallelRuntime(Runtime): ...@@ -210,7 +209,7 @@ class LocalParallelRuntime(Runtime):
# wait for all runs to finish # wait for all runs to finish
await asyncio.wait(self._pending_jobs) await asyncio.wait(self._pending_jobs)
async def start(self): async def start(self) -> None:
"""Execute all defined runs.""" """Execute all defined runs."""
self._starter_task = asyncio.create_task(self.do_start()) self._starter_task = asyncio.create_task(self.do_start())
try: try:
...@@ -221,6 +220,5 @@ class LocalParallelRuntime(Runtime): ...@@ -221,6 +220,5 @@ class LocalParallelRuntime(Runtime):
# wait for all runs to finish # wait for all runs to finish
await asyncio.wait(self._pending_jobs) await asyncio.wait(self._pending_jobs)
def interrupt(self): def interrupt_handler(self) -> None:
super().interrupt()
self._starter_task.cancel() self._starter_task.cancel()
...@@ -117,7 +117,7 @@ class SlurmRuntime(Runtime): ...@@ -117,7 +117,7 @@ class SlurmRuntime(Runtime):
raise RuntimeError('cannot retrieve id of submitted job') raise RuntimeError('cannot retrieve id of submitted job')
run.job_id = int(m.group(1)) run.job_id = int(m.group(1))
async def start(self): async def start(self) -> None:
self._start_task = asyncio.create_task(self._do_start()) self._start_task = asyncio.create_task(self._do_start())
try: try:
await self._start_task await self._start_task
...@@ -134,6 +134,5 @@ class SlurmRuntime(Runtime): ...@@ -134,6 +134,5 @@ class SlurmRuntime(Runtime):
) )
await scancel_process.wait() await scancel_process.wait()
def interrupt(self): def interrupt_handler(self) -> None:
super().interrupt()
self._start_task.cancel() self._start_task.cancel()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment