Unverified Commit 9dfbeb41 authored by Shiyan Deng's avatar Shiyan Deng Committed by GitHub
Browse files

[RFC] allow cancelation after shutdown in blocking collective_rpc (#23390)


Signed-off-by: default avatarShiyan Deng <dsy842974287@meta.com>
parent eedb2a2a
...@@ -253,7 +253,8 @@ class MultiprocExecutor(Executor): ...@@ -253,7 +253,8 @@ class MultiprocExecutor(Executor):
if not non_block: if not non_block:
result = result.result() result = result.result()
elif not non_block: elif not non_block:
result = get_response(w, dequeue_timeout) result = get_response(w, dequeue_timeout,
self.shutdown_event)
else: else:
raise RuntimeError("non_block can only be used when" raise RuntimeError("non_block can only be used when"
" max_concurrent_batches > 1") " max_concurrent_batches > 1")
...@@ -295,12 +296,8 @@ class MultiprocExecutor(Executor): ...@@ -295,12 +296,8 @@ class MultiprocExecutor(Executor):
"""Properly shut down the executor and its workers""" """Properly shut down the executor and its workers"""
if not getattr(self, 'shutting_down', False): if not getattr(self, 'shutting_down', False):
self.shutting_down = True self.shutting_down = True
self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
self.io_thread_pool = None
# Make sure all the worker processes are terminated first.
if workers := getattr(self, 'workers', None): if workers := getattr(self, 'workers', None):
for w in workers: for w in workers:
# Close death_writer to signal child processes to exit # Close death_writer to signal child processes to exit
...@@ -310,6 +307,11 @@ class MultiprocExecutor(Executor): ...@@ -310,6 +307,11 @@ class MultiprocExecutor(Executor):
w.worker_response_mq = None w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in workers]) self._ensure_worker_termination([w.proc for w in workers])
self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
del self.io_thread_pool
self.rpc_broadcast_mq = None self.rpc_broadcast_mq = None
def check_health(self) -> None: def check_health(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment