retry_stream.py 2.49 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from types import TracebackType
from typing import BinaryIO, Optional, Type

from megatron.energon.epathlib import EPath


class RetryReadStream(BinaryIO):
    """A read-only binary stream that reopens its file and retries on I/O errors.

    The stream tracks its own logical position. Whenever the underlying handle
    raises ``IOError`` the handle is discarded, and the next attempt reopens the
    path and seeks back to the logical position before reading again.

    Only reading bytes is supported.
    """

    # Maximum number of failed open/seek/read attempts tolerated per ``read`` call.
    _MAX_ATTEMPTS = 10

    # Path object; must provide ``size()`` and ``open("rb")``.
    _path: "EPath"
    # Currently open handle, or None if nothing is open right now.
    _file: Optional[BinaryIO]
    # Logical read position, in bytes from the start of the file.
    _pos: int
    # Value of ``path.size()`` taken at construction; used as the end offset.
    _size: int

    def __init__(self, path: "EPath"):
        """Construct a RetryReadStream. It reads only bytes from a file.

        The file is opened lazily on the first ``read``; only ``path.size()``
        is queried here.

        Args:
            path: The path to read from.
        """
        self._path = path
        self._file = None
        self._pos = 0
        self._size = path.size()

    def __enter__(self) -> "RetryReadStream":
        """Return the stream itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the stream when leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the underlying handle, if one is open.

        Safe to call repeatedly. The handle reference is dropped, so a later
        ``read`` reopens the file at the current logical position.
        """
        handle, self._file = self._file, None
        if handle is not None:
            handle.close()

    def _discard_file(self) -> None:
        """Drop the current handle after a failure so the next access reopens it."""
        handle, self._file = self._file, None
        if handle is None:
            return
        try:
            handle.close()
        except IOError:
            # Best effort: the handle is already considered broken.
            pass

    def read(self, n: int = -1) -> bytes:
        """Read up to ``n`` bytes, retrying when the underlying handle fails.

        Short reads are accumulated until ``n`` bytes were collected, the
        recorded file size is reached, or the handle reports end of file by
        returning no data.

        Args:
            n: Number of bytes to read. ``None`` or a negative value reads up
                to the end of the file.

        Returns:
            The bytes read; shorter than ``n`` only at end of file.

        Raises:
            IOError: If opening, seeking or reading failed ``_MAX_ATTEMPTS``
                times within this call.
        """
        read_all = n is None or n < 0
        chunks = []
        collected = 0
        failures = 0
        while True:
            # Only ask for what is still missing, so a short read followed by
            # another read never returns more than ``n`` bytes in total.
            want = -1 if read_all else n - collected
            try:
                if self._file is None:
                    self._file = self._path.open("rb")
                    self._file.seek(self._pos)
                chunk = self._file.read(want)
            except IOError:
                # The handle may not exist at all if ``open`` itself failed.
                self._discard_file()
                failures += 1
                if failures >= self._MAX_ATTEMPTS:
                    raise
                continue
            self._pos += len(chunk)
            collected += len(chunk)
            chunks.append(chunk)
            if not chunk or self._pos >= self._size:
                # End of file (reported by the handle or by the recorded size).
                break
            if not read_all and collected >= n:
                break
        return b"".join(chunks)

    def seek(self, offset: int, whence: int = 0) -> int:
        """Move the logical position and return the new absolute position.

        The resulting position is clamped to the range ``[0, size]``.

        Args:
            offset: Offset relative to the reference point given by ``whence``.
            whence: 0 for start of file, 1 for the current position, 2 for the
                end of the file.

        Returns:
            The new absolute position.

        Raises:
            ValueError: If ``whence`` is not 0, 1 or 2.
        """
        if whence == 0:
            target = offset
        elif whence == 1:
            target = self._pos + offset
        elif whence == 2:
            target = self._size + offset
        else:
            raise ValueError(f"Invalid whence value: {whence}")
        self._pos = min(max(target, 0), self._size)
        if self._file is not None:
            try:
                self._file.seek(self._pos)
            except IOError:
                # Keeping a handle whose offset is unknown would make the next
                # read return data from the wrong place; reopen lazily instead.
                self._discard_file()
        return self._pos

    def tell(self) -> int:
        """Return the current logical position."""
        return self._pos

    def isatty(self) -> bool:
        """Return False; this stream is never a terminal."""
        return False

    def readable(self) -> bool:
        """Return True; the stream supports reading."""
        return True

    def seekable(self) -> bool:
        """Return True; the stream supports seeking."""
        return True

    def writable(self) -> bool:
        """Return False; the stream does not support writing."""
        return False