Unverified Commit fc1493a0 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[FrontEnd] Make `merge_async_iterators` `is_cancelled` arg optional (#7282)

parent 311f7438
...@@ -405,7 +405,7 @@ async def iterate_with_cancellation( ...@@ -405,7 +405,7 @@ async def iterate_with_cancellation(
async def merge_async_iterators( async def merge_async_iterators(
*iterators: AsyncGenerator[T, None], *iterators: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]], is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]: ) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator. """Merge multiple asynchronous iterators into a single iterator.
...@@ -413,8 +413,8 @@ async def merge_async_iterators( ...@@ -413,8 +413,8 @@ async def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item. iterator that yields the item.
It also polls the provided function at least once per second to check It also optionally polls a provided function at least once per second
for client cancellation. to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 # Can use anext() in python >= 3.10
...@@ -422,12 +422,13 @@ async def merge_async_iterators( ...@@ -422,12 +422,13 @@ async def merge_async_iterators(
ensure_future(pair[1].__anext__()): pair ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators) for pair in enumerate(iterators)
} }
timeout = None if is_cancelled is None else 1
try: try:
while awaits: while awaits:
done, pending = await asyncio.wait(awaits.keys(), done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
timeout=1) timeout=timeout)
if await is_cancelled(): if is_cancelled is not None and await is_cancelled():
raise asyncio.CancelledError("client cancelled") raise asyncio.CancelledError("client cancelled")
for d in done: for d in done:
pair = awaits.pop(d) pair = awaits.pop(d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment