archiver.py 5.53 KB
Newer Older
1
2
3
4
import datetime
import io
import json
import mmap
researcher2's avatar
researcher2 committed
5
import os
6
from pathlib import Path
Ethan Smith's avatar
Ethan Smith committed
7
from typing import Any
8

researcher2's avatar
researcher2 committed
9
import jsonlines
10
import tqdm
11
import zstandard
researcher2's avatar
researcher2 committed
12

Fabrizio Milo's avatar
Fabrizio Milo committed
13

Ethan Smith's avatar
Ethan Smith committed
14
def json_serial(obj: Any) -> str:
    """Serialize values that the stock ``json`` encoder cannot handle.

    Meant to be passed as ``default=`` to :func:`json.dumps`. Only
    :class:`datetime.datetime` is supported; it is rendered with
    :meth:`datetime.datetime.isoformat`.

    Raises:
        TypeError: for any other type, so ``json.dumps`` reports the value
            as unserializable.
    """
    if not isinstance(obj, datetime.datetime):
        raise TypeError("Type %s not serializable" % type(obj))
    return obj.isoformat()


# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write a single zstd-compressed JSON-lines archive.

    Each record is one ``{"text": ..., "meta": ...}`` JSON object followed by
    a newline. Call :meth:`commit` (or use the instance as a context manager)
    to finish the zstd frame and close the file.
    """

    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        """Open ``file_path`` for writing (creating parent dirs if needed).

        Args:
            file_path: destination path; opened with mode ``"wb"``, so an
                existing file is truncated.
            compression_level: zstd compression level.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def __enter__(self) -> "Archive":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always finish the frame and release the file handle, even on error.
        self.commit()

    def add_data(self, data, meta=None) -> None:
        """Append one record to the archive.

        Args:
            data: value stored under ``"text"``; must be JSON-serializable
                (``datetime.datetime`` values are converted by ``json_serial``).
            meta: optional JSON-serializable metadata stored under ``"meta"``;
                ``None`` (the default) is written as an empty object ``{}``.
        """
        if meta is None:
            # Fresh dict per call rather than a shared mutable default.
            meta = {}
        record = json.dumps({"text": data, "meta": meta}, default=json_serial)
        self.compressor.write(record.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Finish the zstd frame, flush, and close the underlying file.

        Safe to call more than once; calls after the first are no-ops.
        """
        if self.fh.closed:
            return
        try:
            self.compressor.flush(zstandard.FLUSH_FRAME)
            self.fh.flush()
        finally:
            # Release the handle even if flushing the frame fails.
            self.fh.close()

48
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Stream records out of a zstd-compressed JSON-lines archive."""

    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        """Yield the text (and optionally the metadata) of each record in ``file``.

        While iterating, ``self.fh`` holds the open raw (compressed) file
        handle, so callers can peek at it (e.g. for a tqdm progress bar).

        Args:
            file: path to the zstd-compressed JSON-lines archive.
            get_meta: if True, yield ``(text, meta)`` tuples; ``meta`` is
                ``{}`` when a record has no ``"meta"`` key.
            autojoin_paragraphs: when a record's ``"text"`` is a list, join
                its items into one string with ``para_joiner``.
            para_joiner: separator used by ``autojoin_paragraphs``.

        Raises:
            ValueError: if ``get_meta`` is True and the archive contains a
                legacy bare-string record, which carries no metadata.
        """
        with open(file, "rb") as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # Naive jsonl where each object is just the string itself,
                # with no meta. Kept for legacy compatibility.
                if isinstance(ob, str):
                    # Explicit check: an `assert` would vanish under `python -O`
                    # and a bare string would be yielded where a tuple is expected.
                    if get_meta:
                        raise ValueError(
                            "get_meta=True but the archive contains a legacy "
                            "string record without metadata"
                        )
                    yield ob
                    continue

                text = ob["text"]

                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, ob.get("meta", {})
                else:
                    yield text


class TextArchive:
    """Writer of newline-terminated UTF-8 text lines to a plain file."""

    def __init__(self, file_path, mode: str = "rb+") -> None:
        """Ensure ``file_path`` (and its parent directory) exists, then open it.

        Args:
            file_path: path of the output file; created empty if missing.
            mode: binary mode passed to :func:`open`. NOTE: the default
                ``"rb+"`` starts writing at offset 0 of an existing file and
                does not truncate it, so earlier bytes are overwritten in
                place; pass ``"ab"`` to append instead.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # "rb+" refuses to open a missing file, so create it first.
        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def __enter__(self) -> "TextArchive":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the file handle, even on error.
        self.commit()

    def add_data(self, data) -> None:
        """Write ``data`` encoded as UTF-8, followed by a single ``\\n``."""
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Flush and close the file; calls after the first are no-ops."""
        if self.fh.closed:
            return
        try:
            self.fh.flush()
        finally:
            self.fh.close()


class TextReader:
    """Line readers for a UTF-8 text file.

    Every reader strips one trailing ``"\\n"`` from each line (a final line
    without a newline is returned intact). An empty file yields nothing.
    """

    def __init__(self, file_path) -> None:
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        """Yield lines via ``mmap`` while showing a byte-based tqdm progress bar.

        Args:
            update_frequency: number of lines between progress-bar updates;
                updating on every line would slow the loop down.
        """
        total_bytes = os.path.getsize(self.file_path)
        with open(self.file_path, "rb") as fh, tqdm.tqdm(
            total=total_bytes,
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            if total_bytes == 0:
                # mmap cannot map an empty file.
                return
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                reported_position = 0
                line_counter = 0
                for raw_line in iter(mmap_obj.readline, b""):
                    line = raw_line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        position = mmap_obj.tell()
                        progress.update(position - reported_position)
                        reported_position = position
                        line_counter = 0
                    yield line.rstrip("\n")
                # Report bytes read since the last periodic update so the
                # bar reaches its total.
                progress.update(mmap_obj.tell() - reported_position)

    def read_and_tell(self):
        """Yield ``(line, n_bytes)`` pairs via ``mmap``.

        ``n_bytes`` is the number of raw bytes the line occupied in the file,
        including its newline when present.
        """
        if os.path.getsize(self.file_path) == 0:
            # mmap cannot map an empty file.
            return
        with open(self.file_path, "rb") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    yield raw_line.decode("utf-8").rstrip("\n"), len(raw_line)

    def read(self):
        """Yield each line via ``mmap`` (strict UTF-8 decoding)."""
        if os.path.getsize(self.file_path) == 0:
            # mmap cannot map an empty file.
            return
        with open(self.file_path, "rb") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    yield raw_line.decode("utf-8").rstrip("\n")

    def read_slow(self):
        """Yield each line using ordinary buffered text I/O (no ``mmap``).

        Text mode applies universal-newline translation, so ``"\\r\\n"``
        endings are normalized to ``"\\n"`` before it is stripped.
        """
        with open(self.file_path, "r", encoding="utf8") as fh:
            for line in fh:
                yield line.rstrip("\n")


# Optimized for speed. Decompresses the archive with the external `zstd` CLI
# before using the mmap'd TextReader.
class ZStdTextReader:
    """Read lines of a zstd-compressed text file via a temporary decompressed copy."""

    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self):
        """Decompress ``self.file`` next to itself, then yield its lines with progress.

        The decompressed copy is named ``self.file`` minus its last four
        characters (the ``.zst`` suffix). It is deleted once iteration
        finishes, raises, or the generator is closed early.

        Raises:
            subprocess.CalledProcessError: if ``zstd`` exits with an error.
            FileNotFoundError: if the ``zstd`` executable cannot be found.
        """
        import subprocess

        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        # The zstd CLI is faster than in-process decompression. An argument
        # list (no shell) keeps paths with spaces/metacharacters safe; "-f"
        # replaces a stale copy left by an interrupted run, "-o" pins the
        # output name, and "--" stops option parsing before the input path.
        subprocess.run(
            ["zstd", "-d", "-f", "-o", decompressed_file, "--", self.file],
            check=True,
        )
        try:
            yield from TextReader(decompressed_file).read_tqdm()
        finally:
            os.remove(decompressed_file)