archiver.py 5.78 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
# /// script
# requires-python = ">=3.8"
# dependencies = [
#     "jsonlines",
#     "mmap",
#     "tqdm",
#     "zstandard",
# ]
# ///

# ruff: noqa
12
13
14
15
import datetime
import io
import json
import mmap
researcher2's avatar
researcher2 committed
16
import os
17
from pathlib import Path
Ethan Smith's avatar
Ethan Smith committed
18
from typing import Any
19

researcher2's avatar
researcher2 committed
20
import jsonlines
21
import tqdm
22
import zstandard
researcher2's avatar
researcher2 committed
23

Fabrizio Milo's avatar
Fabrizio Milo committed
24

Ethan Smith's avatar
Ethan Smith committed
25
def json_serial(obj: Any) -> str:
    """Fallback ``default=`` hook for ``json.dumps``.

    ``datetime.datetime`` values are rendered with ``isoformat()``; every other
    type is rejected with ``TypeError`` so ``json`` reports it as unserializable.
    """
    if not isinstance(obj, datetime.datetime):
        raise TypeError(f"Type {type(obj)} not serializable")
    return obj.isoformat()

researcher2's avatar
researcher2 committed
32
33
34

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write records as zstd-compressed JSON lines into a single file.

    Each ``add_data`` call appends one ``{"text": ..., "meta": ...}`` line;
    ``commit`` ends the zstd frame and closes the file. An instance can also be
    used as a context manager, which commits on exit.
    """

    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        """Open *file_path* for writing (truncating it), creating parent dirs.

        Args:
            file_path: Destination path of the compressed archive.
            compression_level: Level passed to ``zstandard.ZstdCompressor``.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def __enter__(self) -> "Archive":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always finish the frame and release the file handle, even on error.
        self.commit()

    def add_data(self, data, meta=None) -> None:
        """Append one JSON-lines record.

        Args:
            data: Value stored under ``"text"``; must be JSON-serializable
                (``json_serial`` is used as the fallback encoder).
            meta: Value stored under ``"meta"``; ``None`` is stored as ``{}``.
        """
        if meta is None:
            meta = {}
        record = json.dumps({"text": data, "meta": meta}, default=json_serial)
        self.compressor.write(record.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """End the zstd frame, flush and close the file.

        Calling it again after the file is closed is a no-op, so an explicit
        ``commit()`` inside a ``with`` block does not fail on exit.
        """
        if self.fh.closed:
            return
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
59

researcher2's avatar
researcher2 committed
60
61
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Stream records out of a zstd-compressed JSON-lines archive.

    While ``read`` is iterating, ``self.fh`` refers to the open compressed file
    so callers can peek at its position (e.g. to drive a tqdm progress bar).
    """

    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        """Yield the documents stored in the compressed archive *file*.

        Args:
            file: Path of the zstd-compressed JSON-lines file.
            get_meta: If true, yield ``(text, meta)`` tuples instead of bare
                text; records without a ``"meta"`` key get ``{}``.
            autojoin_paragraphs: If true and a record's ``"text"`` is a list,
                join its items with *para_joiner*.
            para_joiner: Separator used when joining list-valued text.

        Yields:
            ``text`` or ``(text, meta)`` for each record.

        Raises:
            ValueError: If *get_meta* is requested but the archive contains a
                legacy bare-string record, which carries no metadata.
        """
        with open(file, "rb") as fh:
            # Exposed so callers can peek at the compressed read position.
            self.fh = fh
            dctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(dctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # Naive jsonl where each object is just the string itself, with
                # no meta. For legacy compatibility.
                if isinstance(ob, str):
                    # An explicit raise (not assert) so the check survives -O.
                    if get_meta:
                        raise ValueError(
                            "get_meta=True is not supported for legacy archives "
                            "whose records are bare strings without metadata"
                        )
                    yield ob
                    continue

                text = ob["text"]

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, ob.get("meta", {})
                else:
                    yield text

Fabrizio Milo's avatar
Fabrizio Milo committed
94

researcher2's avatar
researcher2 committed
95
class TextArchive:
    """Write UTF-8 text lines to a plain (uncompressed) file.

    An instance can be used as a context manager, which commits on exit.
    """

    def __init__(self, file_path, mode: str = "rb+") -> None:
        """Open *file_path* with *mode*, creating parent dirs and the file first.

        The file is created up front because ``"rb+"`` (the default) refuses to
        open a missing file. Note that ``"rb+"`` starts writing at offset 0 and
        overwrites existing bytes in place; pass ``mode="ab"`` to append.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # Only touch missing files, so an existing file's mtime is left alone.
        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def __enter__(self) -> "TextArchive":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always flush and release the file handle, even on error.
        self.commit()

    def add_data(self, data) -> None:
        """Write *data* (a ``str``) UTF-8 encoded, followed by a newline."""
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Flush and close the file; a no-op if it is already closed."""
        if self.fh.closed:
            return
        self.fh.flush()
        self.fh.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
114

researcher2's avatar
researcher2 committed
115
class TextReader:
    """Iterate over the lines of a UTF-8 text file.

    ``read``, ``read_and_tell`` and ``read_tqdm`` memory-map the file for speed;
    ``read_slow`` uses ordinary buffered text I/O. All of them yield lines with
    their trailing newline removed.
    """

    def __init__(self, file_path) -> None:
        self.file_path = file_path

    @staticmethod
    def _strip_newline(line: str) -> str:
        """Remove the single trailing newline a line may end with.

        The last line of a file that does not end in a newline has none, so
        unconditionally dropping the final character would truncate it.
        """
        return line[:-1] if line.endswith("\n") else line

    def _iter_mmap_lines(self):
        """Yield ``(line, end_offset)`` for every line of the file via mmap.

        *line* is the UTF-8 decoded line without its newline; *end_offset* is
        the byte offset just past it. An empty file yields nothing, because
        ``mmap`` cannot map a zero-length file.
        """
        with open(self.file_path, "rb") as fh:
            if os.fstat(fh.fileno()).st_size == 0:
                return
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw in iter(mmap_obj.readline, b""):
                    yield self._strip_newline(raw.decode("utf-8")), mmap_obj.tell()

    # Optimized mmap read with infrequent tqdm updates to maintain speed.
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        """Yield lines like ``read`` while showing a byte-based tqdm progress bar.

        Args:
            update_frequency: Number of lines between progress-bar refreshes.
        """
        reported = 0
        position = 0
        lines_since_update = 0
        # Nested (not parenthesized) with-statements keep this Python 3.8 compatible.
        with tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            for line, position in self._iter_mmap_lines():
                lines_since_update += 1
                if lines_since_update == update_frequency:
                    progress.update(position - reported)
                    reported = position
                    lines_since_update = 0
                yield line
            # Account for lines read since the last periodic update so the bar
            # ends at the full file size.
            if position > reported:
                progress.update(position - reported)

    def read_and_tell(self):
        """Yield ``(line, raw_bytes_read)`` pairs.

        *raw_bytes_read* is the number of bytes the line occupied in the file,
        newline included.
        """
        previous = 0
        for line, position in self._iter_mmap_lines():
            raw_bytes_read = position - previous
            previous = position
            yield line, raw_bytes_read

    def read(self):
        """Yield each line of the file via mmap."""
        for line, _ in self._iter_mmap_lines():
            yield line

    def read_slow(self):
        """Yield each line using plain buffered text-mode reads (universal newlines)."""
        with open(self.file_path, encoding="utf8") as fh:
            for line in fh:
                yield self._strip_newline(line)

Fabrizio Milo's avatar
Fabrizio Milo committed
172

173
174
175
# Optimized for speed. Decompresses the archive with the external `zstd`
# command-line tool before using the mmap'd TextReader.
class ZStdTextReader:
    """Read lines of a ``.zst``-compressed text file via a decompressed copy."""

    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self, update_frequency: int = 10000):
        """Decompress ``self.file`` next to itself, then yield its lines with a progress bar.

        The decompressed copy is written to the input path minus its last four
        characters (a ``.zst`` suffix is assumed), overwriting any existing file
        there. It is deleted when iteration ends, including when the caller
        stops early or reading raises.

        Args:
            update_frequency: Lines between progress-bar refreshes, forwarded
                to ``TextReader.read_tqdm``.

        Raises:
            FileNotFoundError: If the ``zstd`` executable cannot be found.
            subprocess.CalledProcessError: If ``zstd`` fails to decompress.
        """
        import subprocess

        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        # linux decompress is faster. An argument list (no shell) keeps paths
        # with spaces or shell metacharacters safe; check=True stops us from
        # reading a missing or stale file after a failed decompression.
        subprocess.run(
            ["zstd", "-d", "-f", "-o", decompressed_file, self.file], check=True
        )
        try:
            reader = TextReader(decompressed_file)
            yield from reader.read_tqdm(update_frequency)
        finally:
            os.remove(decompressed_file)