archiver.py 5.33 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
4
5
6
import os
import zstandard
import json
import jsonlines
import io
import datetime
7
8
import mmap
import tqdm
researcher2's avatar
researcher2 committed
9
from pathlib import Path
researcher2's avatar
researcher2 committed
10

Fabrizio Milo's avatar
Fabrizio Milo committed
11

researcher2's avatar
researcher2 committed
12
13
14
15
16
def json_serial(obj):
    """JSON serializer for objects not serializable by default json code.

    Intended for use as the ``default=`` hook of :func:`json.dumps`.
    Dates, datetimes and times are rendered as ISO-8601 strings via their
    ``isoformat()`` method; any other type is rejected.

    Raises:
        TypeError: if ``obj`` is not a ``datetime.date``, ``datetime.datetime``
            or ``datetime.time``.
    """
    # datetime.datetime is a subclass of datetime.date, so this covers both.
    if isinstance(obj, (datetime.date, datetime.time)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))

researcher2's avatar
researcher2 committed
19
20
21
22
23
24
25

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write JSON-lines records into a single zstd-compressed file.

    Each record is serialized as ``{"text": ..., "meta": ...}`` plus a
    newline and streamed through a zstd compressor. Call :meth:`commit`
    once at the end to finish the compressed frame and close the file.
    """

    def __init__(self, file_path, compression_level=3):
        """Open ``file_path`` for writing (truncating it) and set up compression.

        Missing parent directories are created.

        Args:
            file_path: Destination path of the compressed archive.
            compression_level: zstd compression level passed to
                ``zstandard.ZstdCompressor``.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, "wb")
        try:
            self.cctx = zstandard.ZstdCompressor(level=compression_level)
            self.compressor = self.cctx.stream_writer(self.fh)
        except BaseException:
            # Don't leak the just-opened handle if the compressor cannot be
            # constructed.
            self.fh.close()
            raise

    def add_data(self, data, meta=None):
        """Append one record to the archive.

        Args:
            data: The record's text, stored under the ``"text"`` key.
            meta: Optional JSON-serializable metadata stored under ``"meta"``;
                defaults to an empty dict. Datetime values are serialized
                by ``json_serial``.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        if meta is None:
            meta = {}
        record = json.dumps({"text": data, "meta": meta}, default=json_serial)
        self.compressor.write(record.encode("UTF-8") + b"\n")

    def commit(self):
        """Finish the zstd frame, then flush and close the output file."""
        try:
            # FLUSH_FRAME ends the current zstd frame so the file is complete.
            self.compressor.flush(zstandard.FLUSH_FRAME)
            self.fh.flush()
        finally:
            # Always release the file handle, even if flushing fails.
            self.fh.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
44

researcher2's avatar
researcher2 committed
45
46
47
48
49
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Stream records out of a zstd-compressed JSON-lines file.

    While :meth:`read` is iterating, ``self.fh`` holds the underlying
    compressed file handle so callers can peek at it (e.g. for tqdm).
    """

    def __init__(self):
        pass

    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
        """Yield the text of each record in the compressed file ``file``.

        Args:
            file: Path to a zstd-compressed JSON-lines file.
            get_meta: If true, yield ``(text, meta)`` tuples; ``meta`` is
                ``{}`` when a record has no ``"meta"`` key.
            autojoin_paragraphs: If true and a record's ``"text"`` is a list,
                join its items with ``para_joiner``.
            para_joiner: Separator used when joining list-valued text.

        Yields:
            ``text``, or ``(text, meta)`` when ``get_meta`` is true.

        Raises:
            ValueError: if ``get_meta`` is true and a legacy bare-string
                record (which carries no metadata) is encountered.
        """
        with open(file, "rb") as fh:
            # Kept on the instance so callers can peek at the compressed
            # file handle while iterating.
            self.fh = fh
            dctx = zstandard.ZstdDecompressor()
            # Close the decompression stream deterministically instead of
            # leaving it to the garbage collector.
            with io.BufferedReader(dctx.stream_reader(fh)) as reader:
                for ob in jsonlines.Reader(reader):
                    # naive jsonl where each object is just the string itself,
                    # with no meta. For legacy compatibility.
                    if isinstance(ob, str):
                        # Explicit check: an ``assert`` would be stripped
                        # under ``python -O`` and silently mix bare strings
                        # into a stream of (text, meta) tuples.
                        if get_meta:
                            raise ValueError(
                                "get_meta=True is not supported for legacy "
                                "string-only records, which carry no metadata"
                            )
                        yield ob
                        continue

                    text = ob["text"]

                    if autojoin_paragraphs and isinstance(text, list):
                        text = para_joiner.join(text)

                    if get_meta:
                        yield text, ob.get("meta", {})
                    else:
                        yield text

Fabrizio Milo's avatar
Fabrizio Milo committed
73

researcher2's avatar
researcher2 committed
74
class TextArchive:
    """Uncompressed, newline-delimited text archive writer.

    Each call to :meth:`add_data` writes one UTF-8 encoded line;
    :meth:`commit` flushes and closes the file.
    """

    def __init__(self, file_path, mode="rb+"):
        """Open ``file_path`` with ``mode``, creating it (and its parents) if absent.

        NOTE(review): with the default ``"rb+"`` the handle starts at offset
        0, so writes overwrite existing content rather than appending —
        confirm this is intended by callers.
        """
        self.file_path = file_path

        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # "r+"-style modes refuse to open a missing file, so create it empty.
        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

    def add_data(self, data):
        """Write ``data`` as one UTF-8 line terminated by a newline."""
        encoded_line = data.encode("UTF-8") + b"\n"
        self.fh.write(encoded_line)

    def commit(self):
        """Flush buffered output and close the underlying file."""
        handle = self.fh
        handle.flush()
        handle.close()

Fabrizio Milo's avatar
Fabrizio Milo committed
93

researcher2's avatar
researcher2 committed
94
95
96
97
class TextReader:
    """Line-by-line reader for newline-delimited UTF-8 text files.

    All readers yield lines with their trailing newline removed; a final
    line without a newline is yielded intact. The mmap-based readers
    (``read``, ``read_and_tell``, ``read_tqdm``) split on ``b"\\n"`` only,
    so ``"\\r\\n"`` endings leave a trailing ``"\\r"``. ``read_slow`` opens
    the file in text mode and therefore applies universal-newline
    translation. An empty file yields no lines.
    """

    def __init__(self, file_path):
        self.file_path = file_path

    @staticmethod
    def _is_empty(fh):
        """Return True if the open file ``fh`` has zero length."""
        # mmap refuses to map a zero-length file (ValueError), so the mmap
        # readers check this first and treat an empty file as having no lines.
        return os.fstat(fh.fileno()).st_size == 0

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency=10000):
        """Yield lines while displaying a byte-based tqdm progress bar.

        Args:
            update_frequency: Refresh the bar every this many lines; the
                file position is only queried on those lines to keep the
                loop fast.
        """
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, "rb") as fh, tqdm.tqdm(
            total=os.path.getsize(self.file_path),
            dynamic_ncols=True,
            unit="byte",
            unit_scale=1,
        ) as progress:
            if self._is_empty(fh):
                return
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        progress.update(new_file_pos - current_file_position)
                        current_file_position = new_file_pos
                        line_counter = 0
                    # Only strip a newline that is actually there: the last
                    # line of a file may not end with one.
                    yield line[:-1] if line.endswith("\n") else line
                # Report bytes read since the last periodic update so the bar
                # finishes at the full file size.
                progress.update(mmap_obj.tell() - current_file_position)

    def read_and_tell(self):
        """Yield ``(line, raw_bytes_read)`` pairs.

        ``raw_bytes_read`` is the number of bytes the line occupied in the
        file, including its newline, so over a full pass the values sum to
        the file size.
        """
        current_file_position = 0
        with open(self.file_path, "rb") as fh:
            if self._is_empty(fh):
                return
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    text = line[:-1] if line.endswith("\n") else line
                    yield text, raw_bytes_read

    def read(self):
        """Yield each line of the file, decoded as UTF-8, using mmap."""
        with open(self.file_path, "rb") as fh:
            if self._is_empty(fh):
                return
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    yield line[:-1] if line.endswith("\n") else line

    def read_slow(self):
        """Yield each line using ordinary buffered text-mode iteration (no mmap)."""
        with open(self.file_path, "r", encoding="utf8") as fh:
            for line in fh:
                yield line[:-1] if line.endswith("\n") else line

Fabrizio Milo's avatar
Fabrizio Milo committed
148

149
150
151
152
# Optimized for speed. Decompresses the archive with the external zstd CLI
# before using the mmap'd TextReader.
class ZStdTextReader:
    """Read a zstd-compressed text file line by line.

    The archive is decompressed next to the original, under its name with
    the last four characters removed (presumably a ``.zst`` suffix —
    confirm with callers). That copy is read with ``TextReader`` and then
    deleted.
    """

    def __init__(self, file):
        self.file = file

    def read_tqdm(self):
        """Decompress, then yield lines of the decompressed file with a tqdm bar.

        Raises:
            subprocess.CalledProcessError: if ``zstd`` exits with an error
                (e.g. the output file already exists and is not overwritten).
            FileNotFoundError: if the ``zstd`` executable cannot be found.
        """
        import subprocess

        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        # External zstd binary is used for speed. Pass an argument list (no
        # shell) so paths containing spaces or shell metacharacters are safe,
        # and check the exit status instead of silently reading a missing or
        # stale output file. "-o" pins the output to the path read below.
        subprocess.run(["zstd", "-d", self.file, "-o", decompressed_file], check=True)
        try:
            reader = TextReader(decompressed_file)
            yield from reader.read_tqdm()
        finally:
            # Runs even if the consumer stops early or reading fails, so the
            # decompressed copy is not left on disk.
            os.remove(decompressed_file)