Commit ec3ba41e authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

orchestration: change asyncio.wait() to asyncio.gather()

This change is mostly about the default semantics of asyncio.wait(). It
doesn't forward any exceptions that are raised by any of the coroutines
it waits for. This effectively causes exceptions to be swallowed. In
contrast, asyncio.gather() forwards exceptions to its calling coroutine.
parent f1f4cc1c
......@@ -94,7 +94,7 @@ class Component(object):
self._read_stream(self._proc.stderr, self._consume_err)
)
rc = await self._proc.wait()
await asyncio.wait([stdout_handler, stderr_handler])
await asyncio.gather(stdout_handler, stderr_handler)
await self.terminated(rc)
async def send_input(self, bs: bytes, eof=False) -> None:
......@@ -352,7 +352,7 @@ class Executor(abc.ABC):
waiter = asyncio.create_task(self.await_file(p, *args, **kwargs))
xs.append(waiter)
await asyncio.wait(xs)
await asyncio.gather(*xs)
class LocalExecutor(Executor):
......
......@@ -128,7 +128,7 @@ class ExperimentBaseRunner(ABC):
executor = self.sim_executor(host)
task = asyncio.create_task(executor.send_file(path, self.verbose))
copies.append(task)
await asyncio.wait(copies)
await asyncio.gather(*copies)
# prepare all simulators in parallel
sims = []
......@@ -141,7 +141,7 @@ class ExperimentBaseRunner(ABC):
)
)
sims.append(task)
await asyncio.wait(sims)
await asyncio.gather(*sims)
async def wait_for_sims(self) -> None:
"""Wait for simulators to terminate (the ones marked to wait on)."""
......@@ -173,7 +173,7 @@ class ExperimentBaseRunner(ABC):
for (executor, sock) in self.sockets:
scs.append(asyncio.create_task(executor.rmtree(sock)))
if scs:
await asyncio.wait(scs)
await asyncio.gather(*scs)
# add all simulator components to the output
for sim, sc in self.running:
......@@ -197,7 +197,7 @@ class ExperimentBaseRunner(ABC):
sims.append(sim)
# wait for starts to complete
await asyncio.wait(starting)
await asyncio.gather(*starting)
for sim in sims:
ts.done(sim)
......
......@@ -207,7 +207,7 @@ class LocalParallelRuntime(Runtime):
self._pending_jobs.add(job)
# wait for all runs to finish
await asyncio.wait(self._pending_jobs)
await asyncio.gather(*self._pending_jobs)
async def start(self) -> None:
"""Execute all defined runs."""
......@@ -218,7 +218,7 @@ class LocalParallelRuntime(Runtime):
for job in self._pending_jobs:
job.cancel()
# wait for all runs to finish
await asyncio.wait(self._pending_jobs)
await asyncio.gather(*self._pending_jobs)
def interrupt_handler(self) -> None:
self._starter_task.cancel()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment