archiver.py 5.53 KB
Newer Older
researcher2's avatar
researcher2 committed
1
import os
Ethan Smith's avatar
Ethan Smith committed
2
from typing import Any
researcher2's avatar
researcher2 committed
3
4
5
6
7
import zstandard
import json
import jsonlines
import io
import datetime
8
9
import mmap
import tqdm
researcher2's avatar
researcher2 committed
10
from pathlib import Path
researcher2's avatar
researcher2 committed
11

Fabrizio Milo's avatar
Fabrizio Milo committed
12

Ethan Smith's avatar
Ethan Smith committed
13
def json_serial(obj: Any) -> str:
    """Fallback ``default=`` hook for ``json.dumps``.

    ``datetime.datetime`` values (including subclasses) are rendered with
    ``isoformat()``; every other type is rejected.

    Raises:
        TypeError: if ``obj`` is not a ``datetime.datetime``.
    """
    if not isinstance(obj, datetime.datetime):
        raise TypeError(f"Type {type(obj)} not serializable")
    return obj.isoformat()

researcher2's avatar
researcher2 committed
20
21
22

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write records to a single zstd-compressed JSON-lines file.

    Each call to ``add_data`` writes one line ``{"text": data, "meta": meta}``;
    ``datetime.datetime`` values are serialized via ``json_serial``. The zstd
    frame is only finished by ``commit``. The instance can also be used as a
    context manager, which commits on exit.
    """

    def __init__(self, file_path: str, compression_level: int = 3) -> None:
        """Create parent directories and open ``file_path`` for writing.

        Args:
            file_path: Destination path; an existing file is truncated ("wb").
            compression_level: Level passed to ``zstandard.ZstdCompressor``.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        try:
            self.cctx = zstandard.ZstdCompressor(level=compression_level)
            self.compressor = self.cctx.stream_writer(self.fh)
        except BaseException:
            # Don't leak the open file handle if building the compressor fails.
            self.fh.close()
            raise

    def add_data(self, data, meta=None) -> None:
        """Append one ``{"text": data, "meta": meta}`` JSON line.

        Args:
            data: The text payload (any JSON-serializable value).
            meta: Optional metadata dict; ``None`` (the default) writes ``{}``.
        """
        if meta is None:
            # Fresh dict per call instead of a shared mutable default argument.
            meta = {}
        record = json.dumps({"text": data, "meta": meta}, default=json_serial)
        self.compressor.write(record.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Finish the zstd frame, then flush and close the file.

        Calling it again after the file is closed is a no-op.
        """
        if self.fh.closed:
            return
        try:
            self.compressor.flush(zstandard.FLUSH_FRAME)
            self.fh.flush()
        finally:
            # Release the handle even if flushing fails.
            self.fh.close()

    def __enter__(self) -> "Archive":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always commit; returning None lets any exception propagate.
        self.commit()

Fabrizio Milo's avatar
Fabrizio Milo committed
45

researcher2's avatar
researcher2 committed
46
47
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Read records written by ``Archive`` from a zstd JSON-lines file.

    While ``read`` is iterating, ``self.fh`` holds the open raw (compressed)
    file handle so a caller can peek at it (e.g. ``fh.tell()``) for tqdm
    progress reporting.
    """

    def __init__(self) -> None:
        pass

    def read(
        self,
        file,
        get_meta: bool = False,
        autojoin_paragraphs: bool = True,
        para_joiner: str = "\n\n",
    ):
        """Yield the text of each record in ``file``.

        Args:
            file: Path to a zstd-compressed JSON-lines file.
            get_meta: If true, yield ``(text, meta)`` tuples. Records without a
                ``"meta"`` key, including legacy bare-string records, get ``{}``.
            autojoin_paragraphs: If a record's ``"text"`` is a list, join it
                with ``para_joiner``.
            para_joiner: Separator used when joining list-valued text.

        Yields:
            ``text``, or ``(text, meta)`` when ``get_meta`` is true.
        """
        with open(file, "rb") as fh:
            self.fh = fh
            dctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(dctx.stream_reader(fh))
            for ob in jsonlines.Reader(reader):
                # Naive jsonl where each object is just the string itself, with
                # no meta. For legacy compatibility. (Previously guarded by an
                # `assert`, which is stripped under `python -O`.)
                if isinstance(ob, str):
                    yield (ob, {}) if get_meta else ob
                    continue

                text = ob["text"]
                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, ob.get("meta", {})
                else:
                    yield text

Fabrizio Milo's avatar
Fabrizio Milo committed
80

researcher2's avatar
researcher2 committed
81
class TextArchive:
    """Writer for a plain UTF-8 text file, one record per line.

    NOTE: the default ``mode="rb+"`` opens the file at offset 0 without
    truncating, so writes to a pre-existing non-empty file overwrite it from
    the start. Pass ``mode="ab"`` to append instead. ``mode`` must be a binary
    mode, since ``add_data`` writes bytes. Usable as a context manager, which
    commits on exit.
    """

    def __init__(self, file_path, mode: str = "rb+") -> None:
        """Create parent directories and the file if needed, then open it.

        Args:
            file_path: Path of the text file.
            mode: Binary ``open`` mode (default ``"rb+"``).
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # "rb+" requires the file to exist, so create it first.
        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data) -> None:
        """Write ``data`` UTF-8 encoded, followed by a newline."""
        self.fh.write(data.encode("UTF-8") + b"\n")

    def commit(self) -> None:
        """Flush and close the file; a no-op if it is already closed."""
        if self.fh.closed:
            return
        try:
            self.fh.flush()
        finally:
            # Release the handle even if flushing fails.
            self.fh.close()

    def __enter__(self) -> "TextArchive":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always commit; returning None lets any exception propagate.
        self.commit()

Fabrizio Milo's avatar
Fabrizio Milo committed
100

researcher2's avatar
researcher2 committed
101
class TextReader:
    """Read a plain UTF-8 text file line by line.

    Every reader removes a single trailing ``"\\n"`` from each line; a final
    line without a newline is yielded whole. The mmap-based readers decode raw
    bytes, so a ``"\\r"`` from CRLF line endings is kept. ``read_slow`` instead
    uses text mode, whose universal-newline handling translates line endings.
    """

    def __init__(self, file_path) -> None:
        self.file_path = file_path

    @staticmethod
    def _strip_newline(line: str) -> str:
        """Drop one trailing "\\n" without chopping a final unterminated line."""
        return line[:-1] if line.endswith("\n") else line

    def _mmap_lines(self):
        """Yield ``(decoded_line, end_offset)`` for each line, via mmap.

        ``end_offset`` is the byte offset just past the line. Empty files
        yield nothing, because mmap cannot map a zero-length file.
        """
        if os.path.getsize(self.file_path) == 0:
            return
        with open(self.file_path, "rb") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    yield raw_line.decode("utf-8"), mmap_obj.tell()

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency: int = 10000):
        """Yield lines while advancing a byte-based tqdm progress bar.

        Args:
            update_frequency: Number of lines between progress-bar updates.
        """
        current_file_position = 0
        line_counter = 0
        with tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            end_pos = current_file_position
            for line, end_pos in self._mmap_lines():
                line_counter += 1
                if line_counter == update_frequency:
                    progress.update(end_pos - current_file_position)
                    current_file_position = end_pos
                    line_counter = 0
                yield self._strip_newline(line)
            # Credit the bytes consumed since the last periodic update so the
            # bar reaches the number of bytes actually read.
            progress.update(end_pos - current_file_position)

    def read_and_tell(self):
        """Yield ``(line, raw_bytes_read)``.

        ``raw_bytes_read`` is the encoded size of the line, including its newline.
        """
        previous_pos = 0
        for line, end_pos in self._mmap_lines():
            raw_bytes_read = end_pos - previous_pos
            previous_pos = end_pos
            yield self._strip_newline(line), raw_bytes_read

    def read(self):
        """Yield each line using mmap."""
        for line, _ in self._mmap_lines():
            yield self._strip_newline(line)

    def read_slow(self):
        """Yield each line using ordinary buffered text-mode iteration."""
        with open(self.file_path, "r", encoding="utf8") as fh:
            for line in fh:
                yield self._strip_newline(line)

Fabrizio Milo's avatar
Fabrizio Milo committed
155

156
157
158
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    """Read a ``.zst`` text archive by decompressing it with the ``zstd`` CLI.

    The archive is decompressed to ``self.file[:-4]`` (the path minus a
    4-character suffix such as ``.zst``), read with ``TextReader.read_tqdm``,
    and the decompressed copy is removed afterwards.
    """

    def __init__(self, file) -> None:
        self.file = file

    def read_tqdm(self, update_frequency: int = 10000):
        """Yield lines of the decompressed archive with a tqdm progress bar.

        Args:
            update_frequency: Forwarded to ``TextReader.read_tqdm``.

        Raises:
            subprocess.CalledProcessError: if ``zstd`` exits with an error,
                e.g. when the output file already exists.
            FileNotFoundError: if the ``zstd`` executable is not found.
        """
        import subprocess

        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        # linux decompress is faster. Use an argument list (no shell) so paths
        # are passed verbatim. "-o" pins the output to the path read below,
        # and "--" stops option parsing. check=True avoids silently reading
        # (and then deleting) a stale file when decompression fails.
        subprocess.run(
            ["zstd", "-d", "-o", decompressed_file, "--", self.file], check=True
        )
        try:
            reader = TextReader(decompressed_file)
            yield from reader.read_tqdm(update_frequency)
        finally:
            # Remove the temporary copy even if the caller stops early or
            # reading raises.
            os.remove(decompressed_file)