multi_hot_criteo.py 12.6 KB
Newer Older
xinghao's avatar
xinghao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import zipfile
from collections.abc import Iterator

import numpy as np
import torch
from iopath.common.file_io import PathManager, PathManagerFactory
from pyre_extensions import none_throws
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES
from torchrec.datasets.utils import Batch, PATH_MANAGER_KEY
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class MultiHotCriteoIterDataPipe(IterableDataset):
    """
    Datapipe designed to operate over the MLPerf DLRM v2 synthetic multi-hot dataset.
    This dataset can be created by following the steps in
    torchrec_dlrm/scripts/materialize_synthetic_multihot_dataset.py.
    Each rank reads only the data for the portion of the dataset it is responsible for.

    For the "val" and "test" stages only the first file of each path list is
    split: "val" keeps its first half (including the middle row when the row
    count is odd) and "test" keeps the remaining second half.

    Args:
        stage (str): "train", "val", or "test".
        dense_paths (List[str]): List of path strings to dense npy files.
        sparse_paths (List[str]): List of path strings to multi-hot sparse npz files.
            Each npz must hold one uncompressed int32 array per categorical
            feature, stored under the member names "0.npy", "1.npy", ...
        labels_paths (List[str]): List of path strings to labels npy files.
        batch_size (int): batch size.
        rank (int): rank.
        world_size (int): world size.
        drop_last (Optional[bool]): Whether to drop the last batch if it is incomplete.
        shuffle_batches (bool): Whether to shuffle the rows within each batch.
        shuffle_training_set (bool): Whether to shuffle all samples in the dataset.
            Not implemented for this dataset: passing True together with
            stage="train" raises NotImplementedError.
        shuffle_training_set_random_seed (int): Seed passed to the global NumPy
            random generator (``np.random.seed``), which also drives
            ``shuffle_batches``.
        mmap_mode (bool): Whether to memory-map the dense and labels npy files
            instead of loading them fully into memory. Sparse npz members are
            always memory-mapped.
        hashes (Optional[List[int]]): List of max categorical feature value for each
            feature. Length of this list should be CAT_FEATURE_COUNT. Stored for
            reference only; hashing was already applied when the multi-hot
            dataset was materialized.
        path_manager_key (str): Path manager key used to load from different
            filesystems.

    Raises:
        NotImplementedError: if ``shuffle_training_set`` is True and
            ``stage == "train"``.

    Example::

        datapipe = MultiHotCriteoIterDataPipe(
            stage="train",
            dense_paths=["day_0_dense.npy"],
            sparse_paths=["day_0_sparse_multi_hot.npz"],
            labels_paths=["day_0_labels.npy"],
            batch_size=1024,
            rank=torch.distributed.get_rank(),
            world_size=torch.distributed.get_world_size(),
        )
        batch = next(iter(datapipe))
    """

    def __init__(
        self,
        stage: str,
        dense_paths: list[str],
        sparse_paths: list[str],
        labels_paths: list[str],
        batch_size: int,
        rank: int,
        world_size: int,
        drop_last: bool | None = False,
        shuffle_batches: bool = False,
        shuffle_training_set: bool = False,
        shuffle_training_set_random_seed: int = 0,
        mmap_mode: bool = False,
        hashes: list[int] | None = None,
        path_manager_key: str = PATH_MANAGER_KEY,
    ) -> None:
        self.stage = stage
        self.dense_paths = dense_paths
        self.sparse_paths = sparse_paths
        self.labels_paths = labels_paths
        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.drop_last = drop_last
        self.shuffle_batches = shuffle_batches
        self.shuffle_training_set = shuffle_training_set
        # Seeds the *global* NumPy RNG; np.random.permutation in
        # _np_arrays_to_batch draws from it when shuffle_batches is enabled.
        np.random.seed(shuffle_training_set_random_seed)
        self.mmap_mode = mmap_mode
        # hashes are not used because they were already applied in the
        # script that generates the multi-hot dataset. None (the default) is
        # kept as None: reshaping np.array(None) to (CAT_FEATURE_COUNT, 1)
        # would raise.
        self.hashes: np.ndarray | None = (
            None
            if hashes is None
            else np.array(hashes).reshape((CAT_FEATURE_COUNT, 1))
        )
        self.path_manager_key = path_manager_key
        # NOTE(review): self.path_manager is created but files in this class are
        # opened directly via np.load / zipfile, not through the path manager.
        self.path_manager: PathManager = PathManagerFactory().get(path_manager_key)

        if shuffle_training_set and stage == "train":
            # Full-dataset shuffling is not implemented for the materialized
            # multi-hot dataset; fail with an explicit error.
            raise NotImplementedError(
                "shuffle_training_set is not supported by "
                "MultiHotCriteoIterDataPipe."
            )

        m = "r" if mmap_mode else None
        self.dense_arrs: list[np.ndarray] = [
            np.load(f, mmap_mode=m) for f in self.dense_paths
        ]
        self.labels_arrs: list[np.ndarray] = [
            np.load(f, mmap_mode=m) for f in self.labels_paths
        ]
        # One entry per sparse file; each is a list of CAT_FEATURE_COUNT
        # memory-mapped per-feature arrays.
        self.sparse_arrs: list[list[np.ndarray]] = [
            [
                self._load_from_npz(sparse_path, f"{feat_id_num}.npy")
                for feat_id_num in range(CAT_FEATURE_COUNT)
            ]
            for sparse_path in self.sparse_paths
        ]

        # "val" and "test" split the first file: val takes the first half
        # (rounded up for an odd row count), test takes the rest.
        len_d0 = len(self.dense_arrs[0])
        second_half_start_index = int(len_d0 // 2 + len_d0 % 2)
        if stage == "val":
            self._slice_first_file(slice(None, second_half_start_index))
        elif stage == "test":
            self._slice_first_file(slice(second_half_start_index, None))
        # Sparse features are not hashed here: hashing was already applied
        # when the multi-hot dataset was materialized.

        self.num_rows_per_file: list[int] = list(map(len, self.dense_arrs))
        total_rows = sum(self.num_rows_per_file)
        # Number of full batches, rounded down to a multiple of world_size so
        # every rank receives the same number of them.
        self.num_full_batches: int = (
            total_rows // batch_size // self.world_size * self.world_size
        )
        # Per-rank size of the trailing partial batch (all zeros if none).
        self.last_batch_sizes: np.ndarray = np.array(
            [0 for _ in range(self.world_size)]
        )
        remainder = total_rows % (self.world_size * batch_size)
        if not self.drop_last and 0 < remainder:
            if remainder < self.world_size:
                # Too few leftover rows for every rank to get at least one:
                # merge the last round of full batches into the final round.
                self.num_full_batches -= self.world_size
                self.last_batch_sizes += batch_size
            else:
                self.last_batch_sizes += remainder // self.world_size
            # Hand the leftover rows to the lowest ranks, one each.
            self.last_batch_sizes[: remainder % self.world_size] += 1

        # Number of ids per sample for each feature (last axis of its array).
        self.multi_hot_sizes: list[int] = [
            multi_hot_feat.shape[-1] for multi_hot_feat in self.sparse_arrs[0]
        ]

        # These values are the same for the KeyedJaggedTensors in all batches, so they
        # are computed once here. This avoids extra work from the KeyedJaggedTensor sync
        # functions. The key list is copied so it does not alias the module constant.
        self.keys: list[str] = list(DEFAULT_CAT_NAMES)
        self.index_per_key: dict[str, int] = {
            key: i for (i, key) in enumerate(self.keys)
        }

    def _slice_first_file(self, rows: slice) -> None:
        """Restrict the dense, labels and sparse arrays of file 0 to ``rows``."""
        self.dense_arrs[0] = self.dense_arrs[0][rows, :]
        self.labels_arrs[0] = self.labels_arrs[0][rows, :]
        self.sparse_arrs[0] = [feats[rows, :] for feats in self.sparse_arrs[0]]

    def _load_from_npz(self, fname: str, npy_name: str) -> np.memmap:
        """Memory-map one ``.npy`` member of an uncompressed ``.npz`` archive.

        Args:
            fname: path to the ``.npz`` (zip) archive.
            npy_name: member name inside the archive, e.g. ``"0.npy"``.

        Returns:
            A read-only ``np.memmap`` over the member's array data.

        Raises:
            KeyError: if ``npy_name`` is not a member of the archive.
            ValueError: if the member is compressed or its dtype is not int32.
        """
        # Locate the .npy payload inside the .npz. The archive handle is closed
        # on exit; np.memmap reopens the file by name.
        with zipfile.ZipFile(fname) as zf:
            info = zf.NameToInfo[npy_name]
            # Only stored (uncompressed) members can be memory-mapped.
            if info.compress_type != zipfile.ZIP_STORED:
                raise ValueError(
                    f"{npy_name} in {fname} is compressed; only uncompressed "
                    "(ZIP_STORED) members can be memory-mapped"
                )
            zf.fp.seek(info.header_offset + len(info.FileHeader()) + 20)
            # NOTE(review): ZipFile.open parses the member's local file header
            # through zf.fp, which (in CPython) leaves zf.fp at the start of the
            # member data; the seek above appears to be superseded by it.
            zf.open(npy_name, "r").close()
            # Read the .npy header to learn shape / memory order / dtype.
            version = np.lib.format.read_magic(zf.fp)
            shape, fortran_order, dtype = np.lib.format._read_array_header(
                zf.fp, version
            )
            if dtype != "int32":
                raise ValueError(
                    f"sparse multi-hot dtype is {dtype} but should be int32"
                )
            # The raw array data begins right after the .npy header.
            offset = zf.fp.tell()
            filename = zf.filename
        return np.memmap(
            filename,
            dtype=dtype,
            shape=shape,
            order="F" if fortran_order else "C",
            mode="r",
            offset=offset,
        )

    def _np_arrays_to_batch(
        self,
        dense: np.ndarray,
        sparse: list[np.ndarray],
        labels: np.ndarray,
    ) -> Batch:
        """Convert one batch of NumPy rows into a torchrec ``Batch``.

        Args:
            dense: dense features, one row per sample.
            sparse: one 2-D array per categorical feature; row i holds the
                multi-hot ids of sample i.
            labels: labels, one row per sample (flattened to 1-D in the output).

        Returns:
            A ``Batch`` whose sparse features form a ``KeyedJaggedTensor`` where
            every sample of feature k has ``self.multi_hot_sizes[k]`` ids.
        """
        if self.shuffle_batches:
            # Shuffle all 3 in unison
            shuffler = np.random.permutation(len(dense))
            sparse = [multi_hot_ft[shuffler, :] for multi_hot_ft in sparse]
            dense = dense[shuffler]
            labels = labels[shuffler]

        batch_size = len(dense)
        # Lengths are feature-major: batch_size entries for feature 0, then
        # feature 1, ..., each equal to that feature's multi-hot size.
        lengths = torch.ones((CAT_FEATURE_COUNT * batch_size), dtype=torch.int32)
        for k, multi_hot_size in enumerate(self.multi_hot_sizes):
            lengths[k * batch_size : (k + 1) * batch_size] = multi_hot_size
        offsets = torch.cumsum(torch.concat((torch.tensor([0]), lengths)), dim=0)
        length_per_key = [
            batch_size * multi_hot_size for multi_hot_size in self.multi_hot_sizes
        ]
        offset_per_key = torch.cumsum(
            torch.concat((torch.tensor([0]), torch.tensor(length_per_key))), dim=0
        )
        # Concatenate in NumPy so torch receives a fresh, writable array;
        # torch.from_numpy warns on non-writable inputs such as the mode="r"
        # memmap slices that can reach this method.
        values = torch.from_numpy(
            np.concatenate([feat.reshape(-1) for feat in sparse])
        )
        return Batch(
            dense_features=torch.from_numpy(dense.copy()),
            sparse_features=KeyedJaggedTensor(
                keys=self.keys,
                values=values,
                lengths=lengths,
                offsets=offsets,
                stride=batch_size,
                length_per_key=length_per_key,
                offset_per_key=offset_per_key.tolist(),
                index_per_key=self.index_per_key,
            ),
            labels=torch.from_numpy(labels.reshape(-1).copy()),
        )

    def __iter__(self) -> Iterator[Batch]:
        """Yield this rank's batches; global batch ``i`` goes to rank ``i % world_size``."""
        # Invariant: buffer never contains more than batch_size rows.
        buffer: list[np.ndarray] | None = None

        def append_to_buffer(
            dense: np.ndarray,
            sparse: list[np.ndarray],
            labels: np.ndarray,
        ) -> None:
            nonlocal buffer
            if buffer is None:
                buffer = [dense, sparse, labels]
            else:
                buffer[0] = np.concatenate((buffer[0], dense))
                buffer[1] = [np.concatenate((b, s)) for b, s in zip(buffer[1], sparse)]
                buffer[2] = np.concatenate((buffer[2], labels))

        # Maintain a buffer that can contain up to batch_size rows. Fill buffer as
        # much as possible on each iteration. Only return a new batch when batch_size
        # rows are filled. Rows of batches owned by other ranks are counted but
        # never read.
        file_idx = 0
        row_idx = 0
        batch_idx = 0
        buffer_row_count = 0
        cur_batch_size = (
            self.batch_size if self.num_full_batches > 0 else self.last_batch_sizes[0]
        )
        while (
            batch_idx
            < self.num_full_batches + (self.last_batch_sizes[0] > 0) * self.world_size
        ):
            if buffer_row_count == cur_batch_size or file_idx == len(self.dense_arrs):
                if batch_idx % self.world_size == self.rank:
                    yield self._np_arrays_to_batch(*none_throws(buffer))
                    buffer = None
                buffer_row_count = 0
                batch_idx += 1
                # Entering the final (partial) round: each rank uses its own size.
                if 0 <= batch_idx - self.num_full_batches < self.world_size and (
                    self.last_batch_sizes[0] > 0
                ):
                    cur_batch_size = self.last_batch_sizes[
                        batch_idx - self.num_full_batches
                    ]
            else:
                rows_to_get = min(
                    cur_batch_size - buffer_row_count,
                    self.num_rows_per_file[file_idx] - row_idx,
                )
                buffer_row_count += rows_to_get
                slice_ = slice(row_idx, row_idx + rows_to_get)

                if batch_idx % self.world_size == self.rank:
                    dense_inputs = self.dense_arrs[file_idx][slice_, :]
                    sparse_inputs = [
                        feats[slice_, :] for feats in self.sparse_arrs[file_idx]
                    ]
                    target_labels = self.labels_arrs[file_idx][slice_, :]
                    append_to_buffer(
                        dense_inputs,
                        sparse_inputs,
                        target_labels,
                    )
                row_idx += rows_to_get

                if row_idx >= self.num_rows_per_file[file_idx]:
                    file_idx += 1
                    row_idx = 0

    def __len__(self) -> int:
        """Number of batches this rank yields (full batches plus optional partial)."""
        return int(
            self.num_full_batches // self.world_size + (self.last_batch_sizes[0] > 0)
        )