# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for loading text from disk"""
import os
import mmap
import pickle as pkl
import time
from itertools import accumulate

import torch
from torch.multiprocessing import Lock


def get_lazy_path(path):
    """
    Gets directory path where lazy files are stored.
    """
    return os.path.splitext(path)[0] + '.lazy'


def exists_lazy(path, data_type='data'):
    """
    Check if we've already made a lazy version of this file for the `data_type` field.
    """
    if not os.path.exists(get_lazy_path(path)):
        return False
    contents = os.listdir(get_lazy_path(path))
    if data_type not in contents:
        return False
    if data_type + '.len.pkl' not in contents:
        return False
    return True


def make_lazy(path, strs, data_type='data'):
    """
    Make a lazy version of the `data_type` field of the file. Per-sample byte
    lengths, from which end offsets are computed at load time, are stored in a
    `.len.pkl` file alongside the concatenated data.
    """
    lazypath = get_lazy_path(path)
    if not os.path.exists(lazypath):
        os.makedirs(lazypath)
    datapath = os.path.join(lazypath, data_type)
    lenpath = os.path.join(lazypath, data_type + '.len.pkl')
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        with open(datapath, 'wb') as f:
            str_lens = []
            str_cnt = 0
            for s in strs:
                if isinstance(s, dict):
                    s = s['text']
                encoded = s.encode('utf-8')
                f.write(encoded)
                str_cnt = len(encoded)
                str_lens.append(str_cnt)
        pkl.dump(str_lens, open(lenpath, 'wb'))
    else:
        while not os.path.exists(lenpath):
            time.sleep(1)
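
# Example usage (sketch; 'corpus.json' and `samples` are hypothetical — `samples`
# is an iterable of strings or {'text': ...} dicts):
#
#     if not exists_lazy('corpus.json', data_type='text'):
#         make_lazy('corpus.json', samples, data_type='text')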


def split_strings(strings, start, chr_lens):
    """
    Split the concatenated string `strings` into samples using the cumulative
    end offsets `chr_lens`, shifted by the byte position `start` at which
    `strings` was read from the file.
    """
    return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]


class ProcessorTokenizer:
    """
    Callable class that runs an optional preprocessing step and an optional
    tokenization step on input text.
    """

    def __init__(self, tokenizer, process_fn=None):
        self.tokenizer = tokenizer
        self.process_fn = process_fn

    def __call__(self, string):
        if self.tokenizer is not None:
            string = self.tokenizer(string, process_fn=self.process_fn)
        elif self.process_fn is not None:
            string = self.process_fn(string)
        return string
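
# Example (sketch): assuming `tok` is a tokenizer callable that accepts a
# `process_fn` keyword (as __call__ above expects),
#
#     ProcessorTokenizer(tok, process_fn=str.lower)('Hello World')
#
# lowercases the text and then tokenizes it.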


class lazy_array_loader(object):
    """
    Arguments:
        path (str): path to the original data file; its `.lazy` directory holds the
            file where array entries are concatenated into one big string, plus the
            `.len.pkl` length files
        data_type (str): Some datasets have multiple fields that are stored under different paths.
            `data_type` specifies which of these fields to load in this class.
        mem_map (boolean): Specifies whether to memory map the data file.
        map_fn (callable): Fetched strings are passed through map_fn before being returned.

    Example of lazy loader directory structure:
    file.json
    file.lazy/
        data_type1
        data_type1.len.pkl
        data_type2
        data_type2.len.pkl
    """

    def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
        lazypath = get_lazy_path(path)
        datapath = os.path.join(lazypath, data_type)
        # get file where array entries are concatenated into one big string
        self._file = open(datapath, 'rb', buffering=0)
        self.file = self._file
        # memory map file if necessary
        self.mem_map = mem_map
        if self.mem_map:
            self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
        lenpath = os.path.join(lazypath, data_type + '.len.pkl')
        # per-sample byte lengths written by make_lazy, and their running sums,
        # i.e. the byte offset at which each sample ends
        self.lens = pkl.load(open(lenpath, 'rb'))
        self.ends = list(accumulate(self.lens))
        self.dumb_ends = list(self.ends)
        self.read_lock = Lock()
        self.process_fn = map_fn
        self.map_fn = map_fn
        self._tokenizer = None

    def SetTokenizer(self, tokenizer):
        """
        Set or remove (by passing None) the tokenizer, combining preprocessing
        and tokenization into one callable stored as `self.map_fn`.
        """
        if tokenizer is None:
            if not hasattr(self, '_tokenizer'):
                self._tokenizer = tokenizer
        else:
            self._tokenizer = tokenizer
        self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)
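        # note: SetTokenizer(None) leaves the stored tokenizer attribute untouched
        # but rebuilds map_fn so fetched strings only pass through process_fn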

    def GetTokenizer(self):
        return self._tokenizer

    def __getitem__(self, index):
        """
        Read from the file and slice out strings using the string-ending array `self.ends`.
        """
        if not isinstance(index, slice):
            if index == 0:
                start = 0
            else:
                start = self.ends[index - 1]
            end = self.ends[index]
            rtn = self.file_read(start, end)
            if self.map_fn is not None:
                return self.map_fn(rtn)
        else:
            # if slice, fetch strings with 1 diskread and then splice in memory
            chr_lens = self.ends[index]
            if index.start == 0 or index.start is None:
                start = 0
            else:
                start = self.ends[index.start - 1]
            stop = chr_lens[-1]
            strings = self.file_read(start, stop)
            rtn = split_strings(strings, start, chr_lens)
            if self.map_fn is not None:
                return self.map_fn([s for s in rtn])
        return rtn

    def __len__(self):
        return len(self.ends)

    def file_read(self, start=0, end=None):
        """read specified portion of file"""

        # atomic reads to avoid race conditions with multiprocess dataloader
        self.read_lock.acquire()
        # seek to start of file read
        self.file.seek(start)
        # read to end of file if no end point provided
        if end is None:
            rtn = self.file.read()
        # else read amount needed to reach end point
        else:
            rtn = self.file.read(end - start)
        self.read_lock.release()
        # both plain file reads and mmap reads return raw bytes;
        # decode once here to hand back a str
        rtn = rtn.decode('utf-8', 'ignore')
        return rtn
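
# Example usage (sketch; assumes the lazy files for the hypothetical 'corpus.json'
# were written by make_lazy(..., data_type='text') as above):
#
#     loader = lazy_array_loader('corpus.json', data_type='text', mem_map=True)
#     first = loader[0]       # one decoded string
#     batch = loader[10:20]   # list of strings fetched with a single disk read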