archiver.py 5.65 KB
Newer Older
1
2
3
4
import datetime
import io
import json
import mmap
researcher2's avatar
researcher2 committed
5
import os
6
from pathlib import Path
Ethan Smith's avatar
Ethan Smith committed
7
from typing import Any
8

researcher2's avatar
researcher2 committed
9
import jsonlines
10
import tqdm
11
import zstandard
researcher2's avatar
researcher2 committed
12

Fabrizio Milo's avatar
Fabrizio Milo committed
13

Ethan Smith's avatar
Ethan Smith committed
14
def json_serial(obj: Any) -> str:
    """Serialize objects that the default ``json`` encoder cannot handle.

    Intended to be passed as the ``default=`` hook of ``json.dumps``.

    Args:
        obj: The value the ``json`` module could not encode natively.

    Returns:
        The ISO 8601 string (``obj.isoformat()``) for ``datetime.datetime``,
        ``datetime.date`` and ``datetime.time`` values.

    Raises:
        TypeError: If ``obj`` is of any other type.
    """
    # datetime.datetime is a subclass of datetime.date, so it is covered by
    # the datetime.date check as well.
    if isinstance(obj, (datetime.date, datetime.time)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))

researcher2's avatar
researcher2 committed
21
22
23

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write records to a single zstd-compressed JSON-lines file.

    Each record is one JSON object per line with the keys ``"text"`` and
    ``"meta"``, written through a ``zstandard`` stream writer. Call
    :meth:`commit` once all records have been added to finish the frame and
    close the file.
    """

    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        """Create the output file (and its parent directory) for writing.

        Args:
            file_path: Destination path of the compressed archive.
            compression_level: Level passed to ``zstandard.ZstdCompressor``.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Build the compressor before opening the target: "wb" truncates the
        # file, so if constructing the compressor fails we must not have
        # already clobbered the file or left an open handle behind.
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.fh = open(self.file_path, "wb")
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta=None) -> None:
        """Append one record to the archive.

        Args:
            data: Value stored under the ``"text"`` key.
            meta: Value stored under the ``"meta"`` key; defaults to an empty
                dict. Values ``json`` cannot encode are passed to
                ``json_serial``.
        """
        if meta is None:
            meta = {}
        record = json.dumps({"text": data, "meta": meta}, default=json_serial)
        self.compressor.write(record.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Finish the zstd frame, flush, and close the underlying file."""
        try:
            self.compressor.flush(zstandard.FLUSH_FRAME)
            self.fh.flush()
        finally:
            # Always release the file handle, even if flushing fails.
            self.fh.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
48

researcher2's avatar
researcher2 committed
49
50
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Iterate over the records of a zstd-compressed JSON-lines archive.

    While :meth:`read` is running, the open binary file handle is exposed as
    ``self.fh``.
    """

    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        """Yield the documents stored in ``file``.

        Args:
            file: Path of the compressed archive to read.
            get_meta: When true, yield ``(text, meta)`` tuples instead of just
                the text; ``meta`` is ``{}`` for records without a ``"meta"``
                key.
            autojoin_paragraphs: When true, a ``"text"`` value that is a list
                is joined into one string with ``para_joiner``.
            para_joiner: Separator used when joining paragraph lists.

        Yields:
            The text of each record, or ``(text, meta)`` if ``get_meta``.
        """
        with open(file, "rb") as fh:
            self.fh = fh
            decompressor = zstandard.ZstdDecompressor()
            buffered = io.BufferedReader(decompressor.stream_reader(fh))
            for record in jsonlines.Reader(buffered):
                # Legacy format: the JSON value is the bare string itself and
                # carries no metadata.
                if isinstance(record, str):
                    assert not get_meta
                    yield record
                    continue

                text = record["text"]
                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if not get_meta:
                    yield text
                else:
                    yield text, record.get("meta", {})

Fabrizio Milo's avatar
Fabrizio Milo committed
83

researcher2's avatar
researcher2 committed
84
class TextArchive:
    """Uncompressed archive that stores one UTF-8 encoded entry per line."""

    def __init__(self, file_path, mode: str = "rb+") -> None:
        """Open ``file_path`` with ``mode``, creating it first if needed.

        Args:
            file_path: Path of the archive file.
            mode: Mode string handed to ``open``.
        """
        self.file_path = file_path

        # Bare file names have no directory component to create.
        if parent_dir := os.path.dirname(file_path):
            os.makedirs(parent_dir, exist_ok=True)

        # Make sure the file exists before it is opened with ``mode``.
        already_there = os.path.exists(file_path)
        if not already_there:
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data) -> None:
        """Write ``data`` encoded as UTF-8, followed by a newline."""
        encoded_line = data.encode("UTF-8") + b"\n"
        self.fh.write(encoded_line)

    def commit(self) -> None:
        """Flush pending writes and close the file handle."""
        handle = self.fh
        handle.flush()
        handle.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
103

researcher2's avatar
researcher2 committed
104
class TextReader:
    """Read a UTF-8 text file line by line.

    All readers yield each line without its trailing ``"\\n"``. A final line
    that has no trailing newline is yielded intact, and an empty file yields
    nothing.
    """

    def __init__(self, file_path) -> None:
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        """Yield lines while reporting byte progress through ``tqdm``.

        Args:
            update_frequency: Number of lines to read between progress-bar
                updates.

        Yields:
            Each line decoded as UTF-8, without its trailing newline.
        """
        total_bytes = os.path.getsize(self.file_path)
        if total_bytes == 0:
            # mmap cannot map an empty file; there is nothing to yield.
            return

        current_file_position = 0
        line_counter = 0
        with (
            open(self.file_path, "rb") as fh,
            tqdm.tqdm(
                total=total_bytes,
                dynamic_ncols=True,
                unit="byte",
                unit_scale=1,
            ) as progress,
        ):
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    # Strip only a real newline so an unterminated last line
                    # does not lose its final character.
                    yield line[:-1] if line.endswith("\n") else line
                # Report the bytes read since the last periodic update so the
                # bar reaches its total.
                progress.update(mmap_obj.tell() - current_file_position)

    def read_and_tell(self):
        """Yield ``(line, raw_bytes_read)`` pairs.

        ``raw_bytes_read`` is how far the mmap position advanced while reading
        that line, i.e. the size of the raw line in bytes including its
        newline.
        """
        if os.path.getsize(self.file_path) == 0:
            # mmap cannot map an empty file; there is nothing to yield.
            return

        current_file_position = 0
        with open(self.file_path, "rb") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    text = line[:-1] if line.endswith("\n") else line
                    yield text, raw_bytes_read

    def read(self):
        """Yield each line of the file using a memory-mapped read."""
        if os.path.getsize(self.file_path) == 0:
            # mmap cannot map an empty file; there is nothing to yield.
            return

        with open(self.file_path, "rb") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1] if line.endswith("\n") else line

    def read_slow(self):
        """Yield each line using ordinary text-mode file iteration (no mmap)."""
        with open(self.file_path, "r", encoding="utf8") as fh:
            for line in fh:
                yield line[:-1] if line.endswith("\n") else line

Fabrizio Milo's avatar
Fabrizio Milo committed
161

162
163
164
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    """Read a zstd-compressed text file by decompressing it to disk first.

    The archive is decompressed with the external ``zstd`` command, the
    result is read through ``TextReader.read_tqdm``, and the decompressed
    file is removed afterwards.
    """

    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self):
        """Yield the lines of the decompressed archive with a progress bar.

        Raises:
            subprocess.CalledProcessError: If ``zstd`` exits with a non-zero
                status.
        """
        # Imported here so this block does not depend on the module's import
        # section.
        import subprocess

        source = os.fspath(self.file)
        # Drop the last four characters of the name; assumes a ".zst" suffix,
        # which is the name ``zstd -d`` writes its output to.
        decompressed_file = source[:-4]
        print("Decompressing file, please wait...")
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through verbatim, and "--" keeps a name
        # starting with "-" from being parsed as an option. check=True makes
        # a failed decompression raise instead of being silently ignored.
        subprocess.run(["zstd", "-d", "--", source], check=True)  # linux decompress is faster
        try:
            reader = TextReader(decompressed_file)
            yield from reader.read_tqdm()
        finally:
            # Runs on exhaustion, on error, and when the consumer stops early,
            # so the decompressed copy is never left behind.
            os.remove(decompressed_file)