# test_item_sampler.py
1
import os
2
import re
3
import unittest
4
from collections import defaultdict
5
from sys import platform
6

7
8
import backend as F

9
10
11
import dgl
import pytest
import torch
12
13
import torch.distributed as dist
import torch.multiprocessing as mp
14
15
16
from dgl import graphbolt as gb


17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def test_ItemSampler_minibatcher():
    """Exercise the default and user-supplied minibatchers of ItemSampler."""
    # No names given: the default minibatcher cannot build a `MiniBatch`,
    # warns, and hands back the raw item list instead.
    sampler = gb.ItemSampler(gb.ItemSet(torch.arange(0, 10)), batch_size=4)
    expected = (
        "Failed to map item list to `MiniBatch` as the names of items are "
        "not provided. Please provide a customized `MiniBatcher`. The "
        "item list is returned as is."
    )
    with pytest.warns(UserWarning, match=re.escape(expected)):
        batch = next(iter(sampler))
        assert not isinstance(batch, gb.MiniBatch)

    # Unrecognized name: the default minibatcher warns but still attaches the
    # field onto the produced `MiniBatch`.
    sampler = gb.ItemSampler(
        gb.ItemSet(torch.arange(0, 10), names="unknown_name"), batch_size=4
    )
    expected = (
        "Unknown item name 'unknown_name' is detected and added into "
        "`MiniBatch`. You probably need to provide a customized "
        "`MiniBatcher`."
    )
    with pytest.warns(UserWarning, match=re.escape(expected)):
        batch = next(iter(sampler))
        assert isinstance(batch, gb.MiniBatch)
        assert batch.unknown_name is not None

    # Recognized name: a proper `MiniBatch` comes out, no warning expected.
    items = gb.ItemSet(torch.arange(0, 10), names="seeds")
    batch = next(iter(gb.ItemSampler(items, batch_size=4)))
    assert isinstance(batch, gb.MiniBatch)
    assert batch.seeds is not None
    assert len(batch.seeds) == 4

    # A user-supplied minibatcher takes precedence over the default one.
    def _to_minibatch(batch, names):
        return gb.MiniBatch(seeds=batch)

    batch = next(
        iter(gb.ItemSampler(items, batch_size=4, minibatcher=_to_minibatch))
    )
    assert isinstance(batch, gb.MiniBatch)
    assert batch.seeds is not None
    assert len(batch.seeds) == 4

71
72
73
74
75
76
77
78
79
80
81
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
    """ItemSampler over an iterable-only (unsized) ItemSet yields MiniBatches."""
    num_ids = 103

    class InvalidLength:
        # Iterable source that deliberately lacks __len__.
        def __iter__(self):
            return iter(torch.arange(0, num_ids))

    seed_nodes = gb.ItemSet(InvalidLength())
    item_set = gb.ItemSet(seed_nodes, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor, so `... is not shuffle` was an
    # identity check against a bool and always True. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_integer(batch_size, shuffle, drop_last):
    """An integer ItemSet samples like arange(0, num_ids)."""
    # Node IDs.
    num_ids = 103
    item_set = gb.ItemSet(num_ids, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; identity comparison with a
    # bool (`is not shuffle`) was always True. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle


132
133
134
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
    """Sampling a tensor-backed ItemSet of seed node IDs."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    item_set = gb.ItemSet(seed_nodes, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle


161
162
163
164
165
166
167
168
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
    """Sampling an ItemSet of (seed node, label) keeps the fields aligned."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    labels = torch.arange(0, num_ids)
    item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        assert len(minibatch.seeds) == len(minibatch.labels)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        minibatch_ids.append(minibatch.seeds)
        minibatch_labels.append(minibatch.labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle
    assert (
        bool(torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]))
        != shuffle
    )


198
199
200
201
202
203
204
205
206
207
208
209
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_graphs(batch_size, shuffle, drop_last):
    """Sampling an ItemSet of DGLGraphs with a pass-through minibatcher."""
    # Graphs of strictly increasing size so order checks are meaningful.
    num_graphs = 103
    num_nodes = 10
    num_edges = 20
    graphs = [
        dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
        for i in range(num_graphs)
    ]
    item_set = gb.ItemSet(graphs, names="graphs")
    # DGLGraph is not supported in gb.MiniBatch yet. Let's use a customized
    # minibatcher to return the original graphs.
    customized_minibatcher = lambda batch, names: batch
    item_sampler = gb.ItemSampler(
        item_set,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        minibatcher=customized_minibatcher,
    )
    minibatch_num_nodes = []
    minibatch_num_edges = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_graphs
        if not is_last or num_graphs % batch_size == 0:
            assert minibatch.batch_size == batch_size
        else:
            if not drop_last:
                assert minibatch.batch_size == num_graphs % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        minibatch_num_nodes.append(minibatch.batch_num_nodes())
        minibatch_num_edges.append(minibatch.batch_num_edges())
    minibatch_num_nodes = torch.cat(minibatch_num_nodes)
    minibatch_num_edges = torch.cat(minibatch_num_edges)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert (
        bool(torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:]))
        != shuffle
    )
    assert (
        bool(torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:]))
        != shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
    """Sampling an ItemSet of (src, dst) node pairs keeps pairs aligned."""
    # Node pairs: row i is (2i, 2i + 1), so dst == src + 1 by construction.
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is None
        src, dst = minibatch.seeds.T
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        # Verify src and dst IDs match.
        assert torch.equal(src + 1, dst)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) != shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
    """Sampling node pairs plus labels keeps all three fields aligned."""
    # Node pairs and labels; label of row i equals its src ID.
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    labels = node_pairs[:, 0]
    item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        src, dst = minibatch.seeds.T
        label = minibatch.labels
        assert len(src) == len(dst)
        assert len(src) == len(label)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        # Verify src/dst IDs and labels match.
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) != shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) != shuffle
    assert bool(torch.all(labels[:-1] <= labels[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
    """Sampling positive node pairs plus negatives with labels and indexes."""
    # Node pairs and negative destinations: num_negs negatives per positive.
    num_ids = 103
    num_negs = 2
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    neg_srcs = node_pairs[:, 0].repeat_interleave(num_negs)
    neg_dsts = torch.arange(2 * num_ids, 2 * num_ids + num_ids * num_negs)
    neg_node_pairs = torch.cat((neg_srcs, neg_dsts)).reshape(2, -1).T
    # Label 1 marks a positive pair, 0 a negative one.
    labels = torch.empty(num_ids * 3)
    labels[:num_ids] = 1
    labels[num_ids:] = 0
    # Each negative carries the index of its originating positive pair.
    indexes = torch.cat(
        (
            torch.arange(0, num_ids),
            torch.arange(0, num_ids).repeat_interleave(num_negs),
        )
    )
    node_pairs = torch.cat((node_pairs, neg_node_pairs))
    item_set = gb.ItemSet(
        (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
    )
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    final_labels = []
    final_indexes = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        assert minibatch.indexes is not None
        src, dst = minibatch.seeds.T
        # Rows labeled 0 are the negative pairs of this batch.
        negs_src = src[~minibatch.labels.to(bool)]
        negs_dst = dst[~minibatch.labels.to(bool)]
        is_last = (i + 1) * batch_size >= num_ids * 3
        if not is_last or num_ids * 3 % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids * 3 % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert negs_src.dim() == 1
        assert negs_dst.dim() == 1
        # Each negative dst maps back to its (even-valued) source ID.
        assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs_dst)
        final_labels.append(minibatch.labels)
        final_indexes.append(minibatch.indexes)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    final_labels = torch.cat(final_labels)
    final_indexes = torch.cat(final_indexes)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) != shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) != shuffle
    assert bool(torch.all(negs_ids[:-1] <= negs_ids[1:])) != shuffle
    # Without shuffling, labels come out as all ones then all zeros.
    assert bool(torch.all(final_labels[:-1] >= final_labels[1:])) != shuffle
    if not drop_last:
        assert final_labels.sum() == num_ids
        # `torch.equal` returns a bool, so a value comparison is exact here.
        assert torch.equal(final_indexes, indexes) != shuffle


407
408
409
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_hyperlink(batch_size, shuffle, drop_last):
    """Sampling 3-wide hyperlink seeds keeps each row's columns aligned."""
    # Seeds: row i is (3i, 3i + 1, 3i + 2), so consecutive columns differ by 1.
    num_ids = 103
    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
    item_set = gb.ItemSet(seeds, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert minibatch.seeds.shape == (expected_batch_size, 3)
        # Verify seeds match.
        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
        # Archive batch.
        seeds_ids.append(minibatch.seeds)
    seeds_ids = torch.cat(seeds_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0])) != shuffle
    assert bool(torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1])) != shuffle
    assert bool(torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
    """Sampling 3-wide seeds plus labels keeps both fields aligned."""
    # Seeds: row i is (3i, 3i + 1, 3i + 2); label of row i is its first column.
    num_ids = 103
    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
    labels = seeds[:, 0]
    item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        label = minibatch.labels
        assert len(minibatch.seeds) == len(label)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert minibatch.seeds.shape == (expected_batch_size, 3)
        assert len(label) == expected_batch_size
        # Verify seeds and labels match.
        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
        # Archive batch.
        seeds_ids.append(minibatch.seeds)
        labels.append(label)
    seeds_ids = torch.cat(seeds_ids)
    labels = torch.cat(labels)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0])) != shuffle
    assert bool(torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1])) != shuffle
    assert bool(torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2])) != shuffle
    assert bool(torch.all(labels[:-1] <= labels[1:])) != shuffle


487
488
489
def test_append_with_other_datapipes():
    """ItemSampler composes with downstream torchdata datapipes."""
    total = 100
    per_batch = 4
    source = gb.ItemSet(torch.arange(0, total), names="seeds")
    pipe = gb.ItemSampler(source, per_batch)
    # torchdata.datapipes.iter.Enumerator
    pipe = pipe.enumerate()
    for step, (index, batch) in enumerate(pipe):
        assert step == index
        assert len(batch.seeds) == per_batch


499
500
501
502
503
504
505
506
507
508
509
510
511
512
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict whose item sets are unsized iterables."""

    class IterableOnly:
        # Iterable source that deliberately lacks __len__.
        def __init__(self, start, stop):
            self._start = start
            self._stop = stop

        def __iter__(self):
            return iter(torch.arange(self._start, self._stop))

    num_ids = 205
    ids = {
        "user": gb.ItemSet(IterableOnly(0, 99), names="seeds"),
        "item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"),
    }
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        # Concatenate per-type seeds. (Renamed from `ids`, which shadowed the
        # outer dict; the unused `chained_ids` bookkeeping was removed.)
        batch_ids = torch.cat(list(minibatch.seeds.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle


545
546
547
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-type seed node IDs."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(torch.arange(0, 99), names="seeds"),
        "item": gb.ItemSet(torch.arange(99, num_ids), names="seeds"),
    }
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        # Concatenate per-type seeds. (Renamed from `ids`, which shadowed the
        # outer dict; the unused `chained_ids` bookkeeping was removed.)
        batch_ids = torch.cat(list(minibatch.seeds.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-type (seed, label) pairs."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(
            (torch.arange(0, 99), torch.arange(0, 99)),
            names=("seeds", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(99, num_ids), torch.arange(99, num_ids)),
            names=("seeds", "labels"),
        ),
    }
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        # Concatenate per-type seeds and labels. (Renamed from `ids`, which
        # shadowed the outer dict; unused `chained_ids` bookkeeping removed.)
        batch_ids = torch.cat(list(minibatch.seeds.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
        batch_labels = torch.cat(list(minibatch.labels.values()))
        assert len(batch_labels) == expected_batch_size
        minibatch_labels.append(batch_labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) != shuffle
    assert (
        bool(torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]))
        != shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-etype (src, dst) node pairs."""
    # Node pairs: in every row, dst == src + 1 by construction.
    num_ids = 103
    total_pairs = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs_like, names="seeds"),
        "user:follow:user": gb.ItemSet(node_pairs_follow, names="seeds"),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= total_pairs
        if not is_last or total_pairs % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_pairs % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        src = []
        dst = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src.append(seeds[:, 0])
            dst.append(seeds[:, 1])
        src = torch.cat(src)
        dst = torch.cat(dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        assert torch.equal(src + 1, dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) != shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-etype (node pair, label) data."""
    # Node pairs and labels: dst == src + 1 and label == src in every row.
    num_ids = 103
    total_ids = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    # (A dead `labels = torch.arange(0, num_ids)` assignment was removed; it
    # was never read before being rebound below.)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(
            (node_pairs_like, node_pairs_like[:, 0]),
            names=("seeds", "labels"),
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs_follow, node_pairs_follow[:, 0]),
            names=("seeds", "labels"),
        ),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_ids % batch_size
            else:
                # With drop_last the trailing partial batch must not appear.
                assert False
        src = []
        dst = []
        label = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src.append(seeds[:, 0])
            dst.append(seeds[:, 1])
        for _, v_label in minibatch.labels.items():
            label.append(v_label)
        src = torch.cat(src)
        dst = torch.cat(dst)
        label = torch.cat(label)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    # BUGFIX: `torch.all(...)` returns a Tensor; `is not shuffle` was a vacuous
    # identity check. Compare values instead.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) != shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) != shuffle
    assert bool(torch.all(labels[:-1] <= labels[1:])) != shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
    """Sample an ItemSetDict of (seeds, labels, indexes) per edge type and
    verify batch sizes, negative-pair bookkeeping, and iteration order."""
    # Head, tail and negative tails.
    num_ids = 103
    # Each edge type contributes num_ids positives plus num_ids * num_negs
    # negatives: 2 * (1 + 2) * num_ids items in total.
    total_ids = 6 * num_ids
    num_negs = 2
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    neg_dsts_like = torch.arange(num_ids * 4, num_ids * 4 + num_ids * num_negs)
    # Pair each source (repeated num_negs times) with its negative dsts.
    neg_node_pairs_like = (
        torch.cat(
            (node_pairs_like[:, 0].repeat_interleave(num_negs), neg_dsts_like)
        )
        .view(2, -1)
        .T
    )
    all_node_pairs_like = torch.cat((node_pairs_like, neg_node_pairs_like))
    # Positives are labeled 1, negatives 0.
    labels_like = torch.empty(num_ids * 3)
    labels_like[:num_ids] = 1
    labels_like[num_ids:] = 0
    # Index of the positive pair each (positive or negative) item belongs to.
    indexes_like = torch.cat(
        (
            torch.arange(0, num_ids),
            torch.arange(0, num_ids).repeat_interleave(num_negs),
        )
    )
    neg_dsts_follow = torch.arange(
        num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
    )
    neg_node_pairs_follow = (
        torch.cat(
            (
                node_pairs_follow[:, 0].repeat_interleave(num_negs),
                neg_dsts_follow,
            )
        )
        .view(2, -1)
        .T
    )
    all_node_pairs_follow = torch.cat(
        (node_pairs_follow, neg_node_pairs_follow)
    )
    labels_follow = torch.empty(num_ids * 3)
    labels_follow[:num_ids] = 1
    labels_follow[num_ids:] = 0
    indexes_follow = torch.cat(
        (
            torch.arange(0, num_ids),
            torch.arange(0, num_ids).repeat_interleave(num_negs),
        )
    )
    data_dict = {
        "user:like:item": gb.ItemSet(
            (all_node_pairs_like, labels_like, indexes_like),
            names=("seeds", "labels", "indexes"),
        ),
        "user:follow:user": gb.ItemSet(
            (all_node_pairs_follow, labels_follow, indexes_follow),
            names=("seeds", "labels", "indexes"),
        ),
    }
    item_set = gb.ItemSetDict(data_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    final_labels = defaultdict(list)
    final_indexes = defaultdict(list)
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        assert minibatch.indexes is not None
        # Every batch is full-sized except possibly the last, which is only
        # smaller when the total is not divisible and drop_last is off.
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_ids % batch_size
            else:
                assert False
        src = []
        dst = []
        negs_src = []
        negs_dst = []
        for etype, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src_etype = seeds[:, 0]
            dst_etype = seeds[:, 1]
            src.append(src_etype)
            dst.append(dst_etype)
            # Label 0 marks the negative pairs.
            negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
            negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
            final_labels[etype].append(minibatch.labels[etype])
            final_indexes[etype].append(minibatch.indexes[etype])
        src = torch.cat(src)
        dst = torch.cat(dst)
        negs_src = torch.cat(negs_src)
        negs_dst = torch.cat(negs_dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs_dst)
        assert negs_src.dim() == 1
        assert negs_dst.dim() == 1
        # Each negative dst maps back to the (even-offset) source it was
        # generated for.
        assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    # Convert to a Python bool before the identity comparison: `torch.all`
    # returns a 0-dim Tensor and `<Tensor> is not <bool>` is vacuously True.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) is not shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) is not shuffle
    # Check pairwise sortedness of the negative dsts (the original compared
    # the tensor to itself, which is always all-True).
    assert bool(torch.all(negs_ids[:-1] <= negs_ids[1:])) is not shuffle
    for etype in data_dict.keys():
        final_labels_etype = torch.cat(final_labels[etype])
        final_indexes_etype = torch.cat(final_indexes[etype])
        # Without shuffling, labels arrive as all 1s followed by all 0s.
        assert (
            bool(torch.all(final_labels_etype[:-1] >= final_labels_etype[1:]))
            is not shuffle
        )
        if not drop_last:
            assert final_labels_etype.sum() == num_ids
            # indexes_follow equals indexes_like, so it serves both etypes.
            assert (
                torch.equal(final_indexes_etype, indexes_follow) is not shuffle
            )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_hyperlink(batch_size, shuffle, drop_last):
    """Sample an ItemSetDict of 3-node hyperlinks (seeds only) and verify
    batch shapes, per-column consecutiveness, and iteration order."""
    # Node pairs.
    num_ids = 103
    total_pairs = 2 * num_ids
    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
    seeds_dict = {
        "user:like:item": gb.ItemSet(seeds_like, names="seeds"),
        "user:follow:user": gb.ItemSet(seeds_follow, names="seeds"),
    }
    item_set = gb.ItemSetDict(seeds_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        # Only seeds were provided; labels/indexes stay unset.
        assert minibatch.labels is None
        assert minibatch.indexes is None
        # Every batch is full-sized except possibly the last, which is only
        # smaller when the total is not divisible and drop_last is off.
        is_last = (i + 1) * batch_size >= total_pairs
        if not is_last or total_pairs % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_pairs % batch_size
            else:
                assert False
        seeds_lst = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            seeds_lst.append(seeds)
        seeds_lst = torch.cat(seeds_lst)
        assert seeds_lst.shape == (expected_batch_size, 3)
        seeds_ids.append(seeds_lst)
        # Each hyperlink holds three consecutive IDs.
        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
    seeds_ids = torch.cat(seeds_ids)
    # Convert to a Python bool before the identity comparison: `torch.all`
    # returns a 0-dim Tensor and `<Tensor> is not <bool>` is vacuously True.
    assert (
        bool(torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0])) is not shuffle
    )
    assert (
        bool(torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1])) is not shuffle
    )
    assert (
        bool(torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2])) is not shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_hyperlink_labels(batch_size, shuffle, drop_last):
    """Sample an ItemSetDict of 3-node hyperlinks with labels and verify
    batch shapes, label alignment, and iteration order."""
    # Node pairs and labels
    num_ids = 103
    total_ids = 2 * num_ids
    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
    # The label of each hyperlink is its first column.
    seeds_dict = {
        "user:like:item": gb.ItemSet(
            (seeds_like, seeds_like[:, 0]),
            names=("seeds", "labels"),
        ),
        "user:follow:user": gb.ItemSet(
            (seeds_follow, seeds_follow[:, 0]),
            names=("seeds", "labels"),
        ),
    }
    item_set = gb.ItemSetDict(seeds_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        # No indexes were provided.
        assert minibatch.indexes is None
        # Every batch is full-sized except possibly the last, which is only
        # smaller when the total is not divisible and drop_last is off.
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_ids % batch_size
            else:
                assert False
        seeds_lst = []
        label = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            seeds_lst.append(seeds)
        for _, v_label in minibatch.labels.items():
            label.append(v_label)
        seeds_lst = torch.cat(seeds_lst)
        label = torch.cat(label)
        assert seeds_lst.shape == (expected_batch_size, 3)
        assert len(label) == expected_batch_size
        seeds_ids.append(seeds_lst)
        labels.append(label)
        # Columns are consecutive IDs and the label equals column 0.
        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
        assert torch.equal(seeds_lst[:, 0], label)
    seeds_ids = torch.cat(seeds_ids)
    labels = torch.cat(labels)
    # Convert to a Python bool before the identity comparison: `torch.all`
    # returns a 0-dim Tensor and `<Tensor> is not <bool>` is vacuously True.
    assert (
        bool(torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0])) is not shuffle
    )
    assert (
        bool(torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1])) is not shuffle
    )
    assert (
        bool(torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2])) is not shuffle
    )
    assert bool(torch.all(labels[:-1] <= labels[1:])) is not shuffle


def distributed_item_sampler_subprocess(
    proc_id,
    nprocs,
    item_set,
    num_ids,
    num_workers,
    batch_size,
    drop_last,
    drop_uneven_inputs,
):
    """Worker-process body for ``test_DistributedItemSampler``.

    Joins a gloo process group, samples the shared item set through a
    ``DistributedItemSampler`` pipeline, and checks batch sizes,
    cross-replica batch counts, and that no item is sampled twice overall.

    Parameters
    ----------
    proc_id : int
        Rank of this process in the group.
    nprocs : int
        World size of the process group.
    item_set : gb.ItemSet
        Item set shared by all replicas.
    num_ids : int
        Total number of item IDs (sized for the sampled-count tensor).
    num_workers : int
        Number of DataLoader workers per replica.
    batch_size : int
        Mini-batch size.
    drop_last, drop_uneven_inputs : bool
        Sampler options under test.
    """
    # On Windows, the init method can only be file.
    init_method = (
        f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
        if platform == "win32"
        else "tcp://127.0.0.1:12345"
    )
    dist.init_process_group(
        backend="gloo",  # Use Gloo backend for CPU multiprocessing
        init_method=init_method,
        world_size=nprocs,
        rank=proc_id,
    )

    # Create a DistributedItemSampler.
    item_sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
        drop_uneven_inputs=drop_uneven_inputs,
    )
    # An empty feature store keeps the pipeline shape realistic without
    # actually fetching any features.
    feature_fetcher = gb.FeatureFetcher(
        item_sampler,
        gb.BasicFeatureStore({}),
        [],
    )
    data_loader = gb.DataLoader(feature_fetcher, num_workers=num_workers)

    # Count the numbers of items and batches.
    num_items = 0
    sampled_count = torch.zeros(num_ids, dtype=torch.int32)
    for i in data_loader:
        # Count how many times each item is sampled.
        sampled_count[i.seeds] += 1
        if drop_last:
            # With drop_last, every surviving batch must be full.
            assert i.seeds.size(0) == batch_size
        num_items += i.seeds.size(0)
    num_batches = len(list(item_sampler))

    if drop_uneven_inputs:
        num_batches_tensor = torch.tensor(num_batches)
        dist.broadcast(num_batches_tensor, 0)
        # Test if the number of batches are the same for all processes.
        assert num_batches_tensor == num_batches

    # Add up results from all processes.
    dist.reduce(sampled_count, 0)

    try:
        # Make sure no item is sampled more than once.
        assert sampled_count.max() <= 1
    finally:
        # Always tear down the group so sibling processes can exit cleanly.
        dist.destroy_process_group()


@pytest.mark.parametrize(
    "params",
    [
        # ((total, num_replicas, num_workers, batch_size, drop_last,
        #   drop_uneven_inputs), expected (assigned, used) per worker).
        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),
        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),
        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),
        (
            (53, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, False, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (63, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],
        ),
        (
            (63, 4, 2, 4, False, True),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, True),
            [
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (5, 4),
            ],
        ),
        (
            (65, 4, 2, 4, False, False),
            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
        (
            (65, 4, 2, 4, True, True),
            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
    ],
)
def test_RangeCalculation(params):
    """Check ``gb.internal.calculate_range`` against precomputed answers.

    For every (rank, worker) combination the assigned ranges must tile the
    total contiguously and the (assigned, used) sizes must match ``key``.
    """
    (
        (
            total,
            num_replicas,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        key,
    ) = params
    answer = []
    # Running start offset of the next worker's range; renamed from `sum`,
    # which shadowed the builtin.
    start = 0
    for rank in range(num_replicas):
        # num_workers == 0 still means one implicit worker.
        for worker_id in range(max(num_workers, 1)):
            result = gb.internal.calculate_range(
                True,
                total,
                num_replicas,
                rank,
                num_workers,
                worker_id,
                batch_size,
                drop_last,
                drop_uneven_inputs,
            )
            # Ranges must be contiguous: each starts where the previous ends.
            assert start == result[0]
            start += result[1]
            answer.append((result[1], result[2]))
    assert key == answer


@unittest.skipIf(F._default_context_str != "cpu", reason="GPU not required.")
@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
@pytest.mark.parametrize("num_workers", [0, 2])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
def test_DistributedItemSampler(
    num_ids, num_workers, drop_last, drop_uneven_inputs
):
    """Spawn 4 processes and validate ``DistributedItemSampler`` end to end.

    Each process runs ``distributed_item_sampler_subprocess``, which checks
    batch sizes, batch-count parity across replicas, and that no item is
    sampled more than once in total.
    """
    nprocs = 4
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")

    # On Windows, if the process group initialization file already exists,
    # the program may hang. So we need to delete it if it exists.
    if platform == "win32":
        try:
            os.remove(os.path.join(os.getcwd(), "dis_tempfile"))
        except FileNotFoundError:
            pass

    mp.spawn(
        distributed_item_sampler_subprocess,
        args=(
            nprocs,
            item_set,
            num_ids,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        nprocs=nprocs,
        join=True,
    )