Commit 741ca616 authored by Hejing Li's avatar Hejing Li
Browse files

Merge remote-tracking branch 'origin/main' into homa

parents c7438747 ec3ba41e
...@@ -94,7 +94,7 @@ class Component(object): ...@@ -94,7 +94,7 @@ class Component(object):
self._read_stream(self._proc.stderr, self._consume_err) self._read_stream(self._proc.stderr, self._consume_err)
) )
rc = await self._proc.wait() rc = await self._proc.wait()
await asyncio.wait([stdout_handler, stderr_handler]) await asyncio.gather(stdout_handler, stderr_handler)
await self.terminated(rc) await self.terminated(rc)
async def send_input(self, bs: bytes, eof=False) -> None: async def send_input(self, bs: bytes, eof=False) -> None:
...@@ -148,7 +148,8 @@ class Component(object): ...@@ -148,7 +148,8 @@ class Component(object):
try: try:
await asyncio.wait_for(self._proc.wait(), delay) await asyncio.wait_for(self._proc.wait(), delay)
return return
except TimeoutError: # before Python 3.11, asyncio.wait_for() throws asyncio.TimeoutError -_-
except (TimeoutError, asyncio.TimeoutError):
print( print(
f'terminating component {self.cmd_parts[0]} ' f'terminating component {self.cmd_parts[0]} '
f'pid {self._proc.pid}', f'pid {self._proc.pid}',
...@@ -159,14 +160,13 @@ class Component(object): ...@@ -159,14 +160,13 @@ class Component(object):
try: try:
await asyncio.wait_for(self._proc.wait(), delay) await asyncio.wait_for(self._proc.wait(), delay)
return return
except TimeoutError: except (TimeoutError, asyncio.TimeoutError):
print( print(
f'killing component {self.cmd_parts[0]} ' f'killing component {self.cmd_parts[0]} '
f'pid {self._proc.pid}', f'pid {self._proc.pid}',
flush=True flush=True
) )
await self.kill() await self.kill()
await self._proc.wait() await self._proc.wait()
async def started(self) -> None: async def started(self) -> None:
...@@ -352,7 +352,7 @@ class Executor(abc.ABC): ...@@ -352,7 +352,7 @@ class Executor(abc.ABC):
waiter = asyncio.create_task(self.await_file(p, *args, **kwargs)) waiter = asyncio.create_task(self.await_file(p, *args, **kwargs))
xs.append(waiter) xs.append(waiter)
await asyncio.wait(xs) await asyncio.gather(*xs)
class LocalExecutor(Executor): class LocalExecutor(Executor):
......
...@@ -128,7 +128,7 @@ class ExperimentBaseRunner(ABC): ...@@ -128,7 +128,7 @@ class ExperimentBaseRunner(ABC):
executor = self.sim_executor(host) executor = self.sim_executor(host)
task = asyncio.create_task(executor.send_file(path, self.verbose)) task = asyncio.create_task(executor.send_file(path, self.verbose))
copies.append(task) copies.append(task)
await asyncio.wait(copies) await asyncio.gather(*copies)
# prepare all simulators in parallel # prepare all simulators in parallel
sims = [] sims = []
...@@ -141,7 +141,7 @@ class ExperimentBaseRunner(ABC): ...@@ -141,7 +141,7 @@ class ExperimentBaseRunner(ABC):
) )
) )
sims.append(task) sims.append(task)
await asyncio.wait(sims) await asyncio.gather(*sims)
async def wait_for_sims(self) -> None: async def wait_for_sims(self) -> None:
"""Wait for simulators to terminate (the ones marked to wait on).""" """Wait for simulators to terminate (the ones marked to wait on)."""
...@@ -162,24 +162,24 @@ class ExperimentBaseRunner(ABC): ...@@ -162,24 +162,24 @@ class ExperimentBaseRunner(ABC):
scs = [] scs = []
for _, sc in self.running: for _, sc in self.running:
scs.append(asyncio.create_task(sc.int_term_kill())) scs.append(asyncio.create_task(sc.int_term_kill()))
await asyncio.shield(asyncio.wait(scs)) await asyncio.gather(*scs)
# wait for all processes to terminate # wait for all processes to terminate
for _, sc in self.running: for _, sc in self.running:
await asyncio.shield(sc.wait()) await sc.wait()
# remove all sockets # remove all sockets
scs = [] scs = []
for (executor, sock) in self.sockets: for (executor, sock) in self.sockets:
scs.append(asyncio.create_task(executor.rmtree(sock))) scs.append(asyncio.create_task(executor.rmtree(sock)))
if scs: if scs:
await asyncio.shield(asyncio.wait(scs)) await asyncio.gather(*scs)
# add all simulator components to the output # add all simulator components to the output
for sim, sc in self.running: for sim, sc in self.running:
self.out.add_sim(sim, sc) self.out.add_sim(sim, sc)
await asyncio.shield(self.after_cleanup()) await self.after_cleanup()
return self.out return self.out
async def run(self) -> ExpOutput: async def run(self) -> ExpOutput:
...@@ -197,7 +197,7 @@ class ExperimentBaseRunner(ABC): ...@@ -197,7 +197,7 @@ class ExperimentBaseRunner(ABC):
sims.append(sim) sims.append(sim)
# wait for starts to complete # wait for starts to complete
await asyncio.wait(starting) await asyncio.gather(*starting)
for sim in sims: for sim in sims:
ts.done(sim) ts.done(sim)
...@@ -222,7 +222,8 @@ class ExperimentBaseRunner(ABC): ...@@ -222,7 +222,8 @@ class ExperimentBaseRunner(ABC):
while True: while True:
try: try:
return await asyncio.shield(terminate_collect_task) return await asyncio.shield(terminate_collect_task)
except asyncio.CancelledError: except asyncio.CancelledError as e:
print(e)
pass pass
......
...@@ -207,7 +207,7 @@ class LocalParallelRuntime(Runtime): ...@@ -207,7 +207,7 @@ class LocalParallelRuntime(Runtime):
self._pending_jobs.add(job) self._pending_jobs.add(job)
# wait for all runs to finish # wait for all runs to finish
await asyncio.wait(self._pending_jobs) await asyncio.gather(*self._pending_jobs)
async def start(self) -> None: async def start(self) -> None:
"""Execute all defined runs.""" """Execute all defined runs."""
...@@ -218,7 +218,7 @@ class LocalParallelRuntime(Runtime): ...@@ -218,7 +218,7 @@ class LocalParallelRuntime(Runtime):
for job in self._pending_jobs: for job in self._pending_jobs:
job.cancel() job.cancel()
# wait for all runs to finish # wait for all runs to finish
await asyncio.wait(self._pending_jobs) await asyncio.gather(*self._pending_jobs)
def interrupt_handler(self) -> None: def interrupt_handler(self) -> None:
self._starter_task.cancel() self._starter_task.cancel()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment