non_blocking_io.py 14.7 KB
Newer Older
yuguo960516's avatar
bloom  
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import concurrent.futures
import io
import logging
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import IO, Callable, Optional, Union

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/non_blocking_io.py
# --------------------------------------------------------


"""
This file is used for asynchronous file operations.
When `opena` is called for the first time for a specific
`PathHandler`, a `NonBlockingIOManager` is instantiated. The
manager returns a `NonBlockingIO` (or `NonBlockingBufferedIO`)
instance to the caller, and the manager maintains all of the
thread management and data management.
"""


@dataclass
class PathData:
    """
    Per-path bookkeeping owned by `NonBlockingIOManager`.

    Every path opened through `opena` gets exactly one job `Queue` and
    one polling `Thread` that drains it. Because a single thread consumes
    the queue and waits for each job before taking the next one, the
    `f.write` calls issued on a `NonBlockingIO` for that path reach the
    file in the same order in which they were made.
    """

    # FIFO of pending IO callables; a `None` entry tells the poller to exit.
    queue: Queue
    # The thread that pulls jobs off `queue` and runs them one at a time.
    thread: Thread


class NonBlockingIOManager:
    """
    All `opena` calls pass through this class so that it can
    keep track of the threads for proper cleanup at the end
    of the script. Each path that is opened with `opena` is
    assigned a single queue and polling thread that is kept
    open until it is cleaned up by `PathManager.async_join()`.

    If an IO job for a path raises, that path's polling thread logs the
    error and stops running further jobs for the path. The next `_join`
    that covers the path then returns `False` rather than reporting success.
    """

    def __init__(
        self,
        buffered: Optional[bool] = False,
        executor: Optional[concurrent.futures.Executor] = None,
    ) -> None:
        """
        Args:
            buffered (bool): IO instances will be `NonBlockingBufferedIO`
                or `NonBlockingIO` based on this value. This bool is set
                manually for each `PathHandler` in `_opena`.
            executor: User can optionally attach a custom executor to
                perform async operations through `PathHandler.__init__`.
                When omitted, a new `concurrent.futures.ThreadPoolExecutor`
                is created.
        """
        self._path_to_data = {}  # Map from path to `PathData` object
        # Paths whose polling thread saw a job raise. Added to by the polling
        # threads; read and cleared by `_join` once the thread has been joined.
        self._failed_paths = set()
        self._buffered = buffered
        self._IO = NonBlockingBufferedIO if self._buffered else NonBlockingIO
        self._pool = executor or concurrent.futures.ThreadPoolExecutor()

    def get_non_blocking_io(
        self,
        path: str,
        io_obj: Union[IO[str], IO[bytes]],
        callback_after_file_close: Optional[Callable[[], None]] = None,
        buffering: Optional[int] = -1,
    ) -> Union[IO[str], IO[bytes]]:
        """
        Called by `PathHandler._opena` with the path and returns a
        `NonBlockingIO` instance.
        Args:
            path (str): A path str to operate on. This path should be
                simplified to ensure that each absolute path has only a single
                path str that maps onto it. For example, in `NativePathHandler`,
                we can use `os.path.normpath`.
            io_obj (IO): a reference to the IO object returned by the
                `PathHandler._open` function.
            callback_after_file_close (Callable): An optional zero-argument
                callable that can be passed to perform operations that depend
                on the asynchronous writes being completed. The file is first
                written to the local disk and then the callback is executed.
            buffering (int): An optional argument to set the buffer size for
                buffered asynchronous writing.
        Returns:
            A `NonBlockingBufferedIO` if the manager is buffered, otherwise a
            `NonBlockingIO`, wrapping `io_obj`.
        Raises:
            ValueError: if `buffering` is set on a non-buffered manager.
        """
        if not self._buffered and buffering != -1:
            raise ValueError(
                "NonBlockingIO is not using a buffered writer but `buffering` "
                f"arg is set to non-default value of {buffering} != -1."
            )

        if path not in self._path_to_data:
            # Initialize job queue and a polling thread. The path is handed to
            # the poller so that a failing job can be attributed to it.
            queue = Queue()
            t = Thread(target=self._poll_jobs, args=(queue, path))
            t.start()
            # Store the `PathData`
            self._path_to_data[path] = PathData(queue, t)

        kwargs = {} if not self._buffered else {"buffering": buffering}
        return self._IO(
            # The queue is looked up at call time, so using the returned IO
            # object after `_join(path)` has removed the path raises KeyError.
            notify_manager=lambda io_callable: (  # Pass async jobs to manager
                self._path_to_data[path].queue.put(io_callable)
            ),
            io_obj=io_obj,
            callback_after_file_close=callback_after_file_close,
            **kwargs,
        )

    def _poll_jobs(self, queue: Queue, path: Optional[str] = None) -> None:
        """
        A single thread runs this loop. It waits for an IO callable to be
        placed in a specific path's `Queue` where the queue contains
        callable functions. It then waits for the IO job to be completed
        before looping to ensure write order.
        Args:
            queue: the path's job queue. Items are zero-argument callables,
                or `None` as the signal to exit the loop.
            path: the path this queue serves. If a job raises, the path is
                recorded so that `_join` can report the failure.
        """
        while True:
            # `func` is a callable function (specifically a lambda function)
            # and can be any of:
            #   - func = file.write(b)
            #   - func = file.close()
            #   - func = None
            func = queue.get()  # Blocks until item read.
            if func is None:  # Thread join signal.
                break
            try:
                self._pool.submit(func).result()  # Wait for job to finish.
            except Exception:
                # An exception escaping a thread is invisible to
                # `Thread.join`, so log it and record it for `_join` instead.
                logger = logging.getLogger(__name__)
                logger.exception("`NonBlockingIO` job for %s failed.", path)
                self._failed_paths.add(path)
                # Stop processing this path, as before: later jobs (writes,
                # close callbacks) must not run against a partially written file.
                break

    def _join(self, path: Optional[str] = None) -> bool:
        """
        Waits for write jobs for a specific path or waits for all
        write jobs for the path handler if no path is provided.
        Args:
            path (str): Pass in a file path and will wait for the
                asynchronous jobs to be completed for that file path.
                If no path is passed in, then all threads operating
                on all file paths will be joined.
        Returns:
            True if every joined thread finished and none of its jobs raised,
            otherwise False.
        Raises:
            ValueError: if `path` is given but has no async IO associated.
        """
        if path and path not in self._path_to_data:
            raise ValueError(
                f"{path} has no async IO associated with it. "
                f"Make sure `opena({path})` is called first."
            )
        # If a `_close` call fails, we print the error and continue
        # closing the rest of the IO objects.
        paths_to_close = [path] if path else list(self._path_to_data.keys())
        success = True
        for _path in paths_to_close:
            try:
                path_data = self._path_to_data.pop(_path)
                path_data.queue.put(None)
                path_data.thread.join()
            except Exception:
                logger = logging.getLogger(__name__)
                logger.exception("`NonBlockingIO` thread for %s failed to join.", _path)
                success = False
            # A job failure was already logged by the polling thread; surface
            # it here and clear it so a later reopen of the path starts clean.
            if _path in self._failed_paths:
                self._failed_paths.discard(_path)
                success = False
        return success

    def _close_thread_pool(self) -> bool:
        """
        Closes the ThreadPool.
        Returns:
            True on success; False if `shutdown` raised (the error is logged).
        """
        try:
            self._pool.shutdown()
        except Exception:
            logger = logging.getLogger(__name__)
            logger.exception("`NonBlockingIO` thread pool failed to close.")
            return False
        return True


# NOTE: We currently only support asynchronous writes (not reads).
class NonBlockingIO(io.IOBase):
    def __init__(
        self,
        notify_manager: Callable[[Callable[[], None]], None],
        io_obj: Union[IO[str], IO[bytes]],
        callback_after_file_close: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Returned to the user on an `opena` call. Uses a Queue to manage the
        IO jobs that need to be run to ensure order preservation and a
        polling Thread that checks the Queue. Implementation for these are
        lifted to `NonBlockingIOManager` since `NonBlockingIO` closes upon
        leaving the context block.
        NOTE: Writes to the same path are serialized so they are written in
        the same order as they were called but writes to distinct paths can
        happen concurrently.
        Args:
            notify_manager (Callable): a callback function passed in from the
                `NonBlockingIOManager` so that all IO jobs can be stored in
                the manager. It takes in a single argument, namely another
                callable function.
                Example usage:
                ```
                    notify_manager(lambda: file.write(data))
                    notify_manager(lambda: file.close())
                ```
                Here, we tell `NonBlockingIOManager` to add a write callable
                to the path's Queue, and then to add a close callable to the
                path's Queue. The path's polling Thread then executes the write
                callable, waits for it to finish, and then executes the close
                callable. Using `lambda` allows us to pass callables to the
                manager.
            io_obj (IO): a reference to the IO object returned by the
                `PathHandler._open` function.
            callback_after_file_close (Callable): An optional zero-argument
                callable that can be passed to perform operations that depend
                on the asynchronous writes being completed. The file is first
                written to the local disk and then the callback is executed.
        """
        super().__init__()
        self._notify_manager = notify_manager
        self._io = io_obj
        self._callback_after_file_close = callback_after_file_close

        self._close_called = False

    def readable(self) -> bool:
        return False

    def writable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return True

    def write(self, b: Union[bytes, bytearray]) -> None:
        """
        Called on `f.write()`. Gives the manager the write job to call.
        Nothing is returned: the write happens later on the polling thread.
        """
        self._notify_manager(lambda: self._io.write(b))

    def seek(self, offset: int, whence: int = 0) -> int:
        """
        Called on `f.seek()`. The seek is queued behind earlier jobs, so the
        resulting position is not known here and nothing is returned.
        """
        self._notify_manager(lambda: self._io.seek(offset, whence))

    def tell(self) -> int:
        """
        Called on `f.tell()`.
        Raises:
            ValueError: always; the position is unknown while jobs are pending.
        """
        raise ValueError("ioPath async writes does not support `tell` calls.")

    def truncate(self, size: Optional[int] = None) -> int:
        """
        Called on `f.truncate()`. Queued like `seek`; nothing is returned.
        """
        self._notify_manager(lambda: self._io.truncate(size))

    def close(self) -> None:
        """
        Called on `f.close()` or automatically by the context manager.
        We add the `close` call to the file's queue to make sure that
        the file is not closed before all of the write jobs are complete.
        Repeated calls are no-ops.
        """
        # Without this guard, every extra `close()` enqueued another
        # `_io.close()`. `IOBase.__del__` also re-invokes `close()` on any
        # object whose `closed` flag is still False.
        if self._close_called:
            return
        # `ThreadPool` first closes the file and then executes the callback.
        self._notify_manager(lambda: self._io.close())
        if self._callback_after_file_close:
            self._notify_manager(self._callback_after_file_close)
        self._close_called = True
        # Marks the object closed (`self.closed` becomes True), which also
        # stops the finalizer from calling `close()` a second time.
        super().close()


# NOTE: To use this class, use `buffered=True` in `NonBlockingIOManager`.
# NOTE: This class expects the IO mode to be buffered.
class NonBlockingBufferedIO(io.IOBase):
    MAX_BUFFER_BYTES = 10 * 1024 * 1024  # 10 MiB

    def __init__(
        self,
        notify_manager: Callable[[Callable[[], None]], None],
        io_obj: Union[IO[str], IO[bytes]],
        callback_after_file_close: Optional[Callable[[], None]] = None,
        buffering: int = -1,
    ) -> None:
        """
        Buffered version of `NonBlockingIO`. All write data is stored in an
        IO buffer until the buffer is full, or `flush` or `close` is called.
        Args:
            Same as `NonBlockingIO` args.
            buffering (int): An optional argument to set the buffer size for
                buffered asynchronous writing. Values <= 0 select
                `MAX_BUFFER_BYTES`.
        """
        super().__init__()
        self._notify_manager = notify_manager
        self._io = io_obj
        self._callback_after_file_close = callback_after_file_close

        # Holds a single in-memory buffer: the one currently accepting writes.
        # `flush` replaces it in place instead of appending, so the list does
        # not grow by one closed `BytesIO` per flush.
        self._buffers = [io.BytesIO()]
        self._buffer_size = buffering if buffering > 0 else self.MAX_BUFFER_BYTES
        self._close_called = False

    def readable(self) -> bool:
        return False

    def writable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return False

    def write(self, b: Union[bytes, bytearray]) -> None:
        """
        Called on `f.write()`. Appends `b` to the in-memory buffer and hands
        the buffered data to the manager once the buffer reaches
        `self._buffer_size` bytes. Requires a bytes-like `b`, since it is
        wrapped in a `memoryview`.
        """
        buffer = self._buffers[-1]
        with memoryview(b) as view:
            buffer.write(view)
        if buffer.tell() < self._buffer_size:
            return
        self.flush()

    def close(self) -> None:
        """
        Called on `f.close()` or automatically by the context manager.
        We add the `close` call to the file's queue to make sure that
        the file is not closed before all of the write jobs are complete.
        Repeated calls are no-ops.
        """
        # Previously every extra `close()` (including the one `IOBase.__del__`
        # issues while `closed` is still False) re-flushed and enqueued another
        # round of close jobs.
        if self._close_called:
            return
        # `IOBase.close` calls `self.flush()`, which queues any pending data,
        # and then marks the object closed.
        super().close()
        # Bind the current (empty) buffer now rather than when the job runs.
        last_buffer = self._buffers[-1]
        self._notify_manager(lambda: last_buffer.close())
        # `ThreadPool` first closes the file and then executes the callback.
        self._notify_manager(lambda: self._io.close())
        if self._callback_after_file_close:
            self._notify_manager(self._callback_after_file_close)
        self._close_called = True

    def flush(self) -> None:
        """
        Called on `f.write()` if the buffer is filled (or overfilled). Can
        also be explicitly called by user.
        NOTE: Buffering is used in a strict manner. Any buffer that exceeds
        `self._buffer_size` will be broken into multiple write jobs where
        each has a write call with `self._buffer_size` size.
        """
        buffer = self._buffers[-1]
        if buffer.tell() == 0:
            return
        total_size = buffer.seek(0, io.SEEK_END)
        view = buffer.getbuffer()
        # Chunk the buffer in case it is larger than the buffer size.
        for start in range(0, total_size, self._buffer_size):
            chunk = view[start : start + self._buffer_size]
            # `chunk=chunk` is needed due to Python's late binding closures.
            self._notify_manager(lambda chunk=chunk: self._io.write(chunk))
        # Close the buffer right after its write jobs in the same queue, and
        # swap in a fresh buffer for subsequent writes.
        self._notify_manager(lambda: buffer.close())
        self._buffers[-1] = io.BytesIO()