archiver.py 5.37 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
4
5
6
import os
import zstandard
import json
import jsonlines
import io
import datetime
7
8
import mmap
import tqdm
researcher2's avatar
researcher2 committed
9
from pathlib import Path
researcher2's avatar
researcher2 committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

def json_serial(obj):
    """Fallback ``default=`` hook for ``json.dumps``.

    ``datetime.datetime`` values are rendered as ISO-8601 strings via
    ``isoformat()``. Any other type raises ``TypeError``, which ``json.dumps``
    propagates to its caller.
    """
    if not isinstance(obj, datetime.datetime):
        raise TypeError(f"Type {type(obj)} not serializable")
    return obj.isoformat()

# Modified version of lm_dataformat Archive for single file.
class Archive:
    """Write ``{"text": ..., "meta": ...}`` JSON-lines records into one zstd-compressed file."""

    def __init__(self, file_path, compression_level=3):
        """Open *file_path* for compressed writing, creating parent directories.

        Args:
            file_path: Destination path. It is opened with ``'wb'``, so an
                existing file is truncated.
            compression_level: zstd level passed to ``ZstdCompressor``.
        """
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Build the compressor first: if the level is rejected, nothing has been
        # opened yet, so no handle leaks and no existing file gets truncated.
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.fh = open(self.file_path, 'wb')
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta=None):
        """Append one record as a UTF-8 JSON line.

        Args:
            data: Value stored under ``"text"``.
            meta: Value stored under ``"meta"``. Defaults to an empty dict.
                ``None`` replaces the previous shared mutable ``{}`` default.
                ``datetime`` values are serialized via ``json_serial``.
        """
        if meta is None:
            meta = {}
        line = json.dumps({'text': data, 'meta': meta}, default=json_serial)
        self.compressor.write(line.encode('UTF-8') + b'\n')

    def commit(self):
        """End the zstd frame, flush buffered bytes to disk, and close the file."""
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()

# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    """Stream records back out of a zstd-compressed JSON-lines archive.

    While ``read`` is running, the underlying binary handle is exposed as
    ``self.fh`` so a caller can inspect the raw file position (e.g. to drive a
    progress bar).
    """

    def __init__(self):
        pass

    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
        """Yield the text of every record in *file*.

        Args:
            file: Path to the compressed archive.
            get_meta: When true, yield ``(text, meta)`` tuples. ``meta`` falls
                back to ``{}`` for records without one.
            autojoin_paragraphs: Join list-valued ``text`` with *para_joiner*.
            para_joiner: Separator used when joining paragraphs.

        Legacy archives whose lines are bare JSON strings are yielded as-is.
        Such lines carry no metadata, so ``get_meta`` must be false for them
        (checked with ``assert``).
        """
        with open(file, 'rb') as fh:
            self.fh = fh
            decompressed = io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(fh))
            for record in jsonlines.Reader(decompressed):
                if not isinstance(record, str):
                    text = record['text']
                    if isinstance(text, list) and autojoin_paragraphs:
                        text = para_joiner.join(text)
                    yield (text, record.get('meta', {})) if get_meta else text
                else:
                    # Legacy line: just the string itself, no metadata.
                    assert not get_meta
                    yield record

class TextArchive:
    """Writer for plain UTF-8 text records, one newline-terminated record per call."""

    def __init__(self, file_path, mode="rb+"):
        """Open *file_path* for writing, creating it and its parent directories if missing.

        Args:
            file_path: Target file.
            mode: Mode passed to ``open``. NOTE: with the default ``"rb+"`` the
                handle starts at offset 0 and the file is not truncated, so
                records overwrite any existing bytes from the beginning. Pass
                ``"ab"`` to append instead.
        """
        self.file_path = file_path
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # "rb+" refuses to open a file that does not exist, so create it first.
        if not os.path.exists(file_path):
            Path(file_path).touch()
        self.fh = open(self.file_path, mode)

    def add_data(self, data):
        """Write *data* (a ``str``) UTF-8 encoded, followed by a newline."""
        encoded = data.encode('UTF-8')
        self.fh.write(encoded + b'\n')

    def commit(self):
        """Flush pending writes and close the file."""
        handle = self.fh
        handle.flush()
        handle.close()

class TextReader:
    """Line-oriented reader for plain UTF-8 text files.

    Every reader yields lines with one trailing ``"\\n"`` removed. A final line
    without a newline is yielded intact. The mmap-based readers (``read_tqdm``,
    ``read_and_tell``, ``read``) work on raw bytes, so a ``"\\r"`` before the
    newline is kept. ``read_slow`` uses text mode with universal newlines.
    """

    def __init__(self, file_path):
        self.file_path = file_path

    def _is_empty(self):
        """Return True if the file has zero bytes; ``mmap`` cannot map an empty file."""
        return os.path.getsize(self.file_path) == 0

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s.
    def read_tqdm(self, update_frequency=10000):
        """Yield decoded lines while reporting consumed bytes to a tqdm bar.

        Args:
            update_frequency: The bar is advanced once per this many lines, to
                keep per-line overhead low. After the last line it is advanced
                by the remainder so it reaches the file size.
        """
        if self._is_empty():
            return
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, 'rb') as fh, \
                tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
                          unit="byte", unit_scale=1) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    line = raw_line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell()
                        progress.update(new_file_pos - current_file_position)
                        current_file_position = new_file_pos
                        line_counter = 0
                    yield line[:-1] if line.endswith("\n") else line
                # Report the bytes read since the last periodic update.
                progress.update(mmap_obj.tell() - current_file_position)

    def read_and_tell(self):
        """Yield ``(line, raw_bytes_read)`` pairs.

        ``raw_bytes_read`` is the number of bytes this line occupied in the file,
        including its newline.
        """
        if self._is_empty():
            return
        current_file_position = 0
        with open(self.file_path, 'rb') as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    line = raw_line.decode("utf-8")
                    new_file_pos = mmap_obj.tell()
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield (line[:-1] if line.endswith("\n") else line), raw_bytes_read

    def read(self):
        """Yield each decoded line, reading the file through ``mmap``."""
        if self._is_empty():
            return
        with open(self.file_path, 'rb') as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for raw_line in iter(mmap_obj.readline, b""):
                    line = raw_line.decode("utf-8")
                    yield line[:-1] if line.endswith("\n") else line

    def read_slow(self):
        """Yield each line using ordinary buffered text-mode file iteration."""
        with open(self.file_path, 'r', encoding="utf8") as fh:
            for line in fh:
                yield line[:-1] if line.endswith("\n") else line

# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    """Read a zstd-compressed text file line by line.

    The archive is decompressed with the external ``zstd`` command-line tool,
    which is faster than decompressing in Python. The result is read with the
    mmap-based ``TextReader``.
    """

    def __init__(self, file):
        self.file = file

    def read_tqdm(self):
        """Decompress ``self.file`` and yield its lines with a tqdm progress bar.

        The decompressed copy is named ``self.file`` minus its last four
        characters, which assumes a ``.zst`` suffix. It is deleted when
        iteration finishes, raises, or the generator is closed early.

        Raises:
            subprocess.CalledProcessError: if ``zstd -d`` exits non-zero, e.g.
                the input is corrupt or the output file already exists.
            FileNotFoundError: if the ``zstd`` executable is not on ``PATH``.
        """
        import subprocess  # local import: leaves the module header untouched

        compressed_file = os.fspath(self.file)
        decompressed_file = compressed_file[:-4]
        print("Decompressing file, please wait...")
        # Pass an argument list (no shell) so paths with spaces or shell
        # metacharacters are handled safely. "--" stops option parsing so a
        # file name beginning with "-" is not read as a flag.
        subprocess.run(["zstd", "-d", "--", compressed_file], check=True)
        try:
            reader = TextReader(decompressed_file)
            yield from reader.read_tqdm()
        finally:
            # Only reached after a successful decompression, so this never
            # deletes a file this call did not create.
            os.remove(decompressed_file)