"docs/source/api/vscode:/vscode.git/clone" did not exist on "91fe0c90690d9a7078b0b03dc059088a6f310777"
test_item_sampler.py 26.7 KB
Newer Older
import os
import re
from sys import platform

import dgl
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb
from torch.testing import assert_close


14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def test_ItemSampler_minibatcher():
    """Check default and user-provided minibatchers of ItemSampler."""
    # Without item names the default minibatcher cannot build a MiniBatch:
    # it warns and returns the raw item list unchanged.
    item_set = gb.ItemSet(torch.arange(0, 10))
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    expected_warning = re.escape(
        "Failed to map item list to `MiniBatch` as the names of items are "
        "not provided. Please provide a customized `MiniBatcher`. The "
        "item list is returned as is."
    )
    with pytest.warns(UserWarning, match=expected_warning):
        batch = next(iter(item_sampler))
        assert not isinstance(batch, gb.MiniBatch)

    # With an unrecognized name the default minibatcher warns but still
    # attaches the field onto the MiniBatch.
    item_set = gb.ItemSet(torch.arange(0, 10), names="unknown_name")
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    expected_warning = re.escape(
        "Unknown item name 'unknown_name' is detected and added into "
        "`MiniBatch`. You probably need to provide a customized "
        "`MiniBatcher`."
    )
    with pytest.warns(UserWarning, match=expected_warning):
        batch = next(iter(item_sampler))
        assert isinstance(batch, gb.MiniBatch)
        assert batch.unknown_name is not None

    # With a recognized name a well-formed MiniBatch is produced.
    item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    batch = next(iter(item_sampler))
    assert isinstance(batch, gb.MiniBatch)
    assert batch.seed_nodes is not None
    assert len(batch.seed_nodes) == 4

    # A user-supplied minibatcher overrides the default one.
    def minibatcher(batch, names):
        return gb.MiniBatch(seed_nodes=batch)

    item_sampler = gb.ItemSampler(
        item_set, batch_size=4, minibatcher=minibatcher
    )
    batch = next(iter(item_sampler))
    assert isinstance(batch, gb.MiniBatch)
    assert batch.seed_nodes is not None
    assert len(batch.seed_nodes) == 4

68
69
70
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSet of node IDs yields correct minibatches."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seed_nodes) == batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            assert len(minibatch.seed_nodes) == num_ids % batch_size
        else:
            assert False
        minibatch_ids.append(minibatch.seed_nodes)
    minibatch_ids = torch.cat(minibatch_ids)
    # IDs come back sorted iff shuffling is disabled. `.item()` is required:
    # torch.all returns a 0-dim Tensor, which is never identical to a Python
    # bool, so `tensor is not shuffle` would pass vacuously.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() is not shuffle


97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
    """ItemSampler over (seed_nodes, labels) pairs yields aligned batches."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    labels = torch.arange(0, num_ids)
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is not None
        assert len(minibatch.seed_nodes) == len(minibatch.labels)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seed_nodes) == batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            assert len(minibatch.seed_nodes) == num_ids % batch_size
        else:
            assert False
        minibatch_ids.append(minibatch.seed_nodes)
        minibatch_labels.append(minibatch.labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    # `.item()` converts the 0-dim result of torch.all to a Python bool;
    # comparing the Tensor itself with `is not` would always pass.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() is not shuffle
    assert (
        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]).item()
        is not shuffle
    )


134
135
136
137
138
139
140
141
142
143
144
145
146
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_graphs(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSet of graphs batches them correctly."""
    # Graphs of strictly increasing size so ordering is detectable below.
    num_graphs = 103
    num_nodes = 10
    num_edges = 20
    graphs = [
        dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
        for i in range(num_graphs)
    ]
    item_set = gb.ItemSet(graphs)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_num_nodes = []
    minibatch_num_edges = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_graphs
        if not is_last or num_graphs % batch_size == 0:
            assert minibatch.batch_size == batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            assert minibatch.batch_size == num_graphs % batch_size
        else:
            assert False
        minibatch_num_nodes.append(minibatch.batch_num_nodes())
        minibatch_num_edges.append(minibatch.batch_num_edges())
    minibatch_num_nodes = torch.cat(minibatch_num_nodes)
    minibatch_num_edges = torch.cat(minibatch_num_edges)
    # Sizes come back sorted iff shuffling is disabled. `.item()` is required:
    # a 0-dim Tensor is never identical to a Python bool, so the bare
    # `is not shuffle` comparison would pass vacuously.
    assert (
        torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:]).item()
        is not shuffle
    )
    assert (
        torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:]).item()
        is not shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
    """ItemSampler over node pairs yields (src, dst) tuples per batch."""
    # Node pairs: consecutive IDs so dst == src + 1 for every pair.
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.node_pairs is not None
        assert isinstance(minibatch.node_pairs, tuple)
        assert minibatch.labels is None
        src, dst = minibatch.node_pairs
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = num_ids % batch_size
        else:
            assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        # Verify src and dst IDs match.
        assert torch.equal(src + 1, dst)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
    """ItemSampler over node pairs with per-pair labels."""
    # Node pairs and labels
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    labels = node_pairs[:, 0]
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.node_pairs is not None
        assert isinstance(minibatch.node_pairs, tuple)
        assert minibatch.labels is not None
        src, dst = minibatch.node_pairs
        label = minibatch.labels
        assert len(src) == len(dst)
        assert len(src) == len(label)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = num_ids % batch_size
        else:
            assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        # Verify src/dst IDs and labels match.
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle
    assert torch.all(labels[:-1] <= labels[1:]).item() is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
    """ItemSampler over node pairs plus negative destinations."""
    # Node pairs and negative destinations.
    num_ids = 103
    num_negs = 2
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    neg_dsts = torch.arange(
        2 * num_ids, 2 * num_ids + num_ids * num_negs
    ).reshape(-1, num_negs)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.node_pairs is not None
        assert isinstance(minibatch.node_pairs, tuple)
        assert minibatch.negative_dsts is not None
        src, dst = minibatch.node_pairs
        negs = minibatch.negative_dsts
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = num_ids % batch_size
        else:
            assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert negs.dim() == 2
        assert negs.shape[0] == expected_batch_size
        assert negs.shape[1] == num_negs
        # Verify node pairs and negative destinations.
        assert torch.equal(src + 1, dst)
        assert torch.equal(negs[:, 0] + 1, negs[:, 1])
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle
    assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]).item() is not shuffle
    assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]).item() is not shuffle


def test_append_with_other_datapipes():
    """ItemSampler composes with downstream torchdata datapipes."""
    num_ids = 100
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids))
    data_pipe = gb.ItemSampler(item_set, batch_size)
    # torchdata.datapipes.iter.Enumerator
    data_pipe = data_pipe.enumerate()
    for i, (idx, data) in enumerate(data_pipe):
        # The enumerator index must track the batch index.
        assert i == idx
        assert len(data) == batch_size


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-type node IDs."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
        "item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = num_ids % batch_size
        else:
            assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        # Flatten the per-type ID dict into a single tensor.
        batch_ids = torch.cat(list(minibatch.seed_nodes.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
    minibatch_ids = torch.cat(minibatch_ids)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-type (IDs, labels) pairs."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(
            (torch.arange(0, 99), torch.arange(0, 99)),
            names=("seed_nodes", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(99, num_ids), torch.arange(99, num_ids)),
            names=("seed_nodes", "labels"),
        ),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is not None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = num_ids % batch_size
        else:
            assert False
        # Flatten the per-type dicts into single tensors.
        batch_ids = torch.cat(list(minibatch.seed_nodes.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
        batch_labels = torch.cat(list(minibatch.labels.values()))
        assert len(batch_labels) == expected_batch_size
        minibatch_labels.append(batch_labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() is not shuffle
    assert (
        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]).item()
        is not shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-etype node pairs."""
    # Node pairs: consecutive IDs so dst == src + 1 for every etype.
    num_ids = 103
    total_pairs = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
        "user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.node_pairs is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= total_pairs
        if not is_last or total_pairs % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = total_pairs % batch_size
        else:
            assert False
        # Gather (src, dst) across all etypes present in this batch.
        src = []
        dst = []
        for _, node_pairs in minibatch.node_pairs.items():
            assert isinstance(node_pairs, tuple)
            src.append(node_pairs[0])
            dst.append(node_pairs[1])
        src = torch.cat(src)
        dst = torch.cat(dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        assert torch.equal(src + 1, dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of per-etype node pairs and labels."""
    # Node pairs and labels
    num_ids = 103
    total_ids = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(
            (node_pairs_like, node_pairs_like[:, 0]),
            names=("node_pairs", "labels"),
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs_follow, node_pairs_follow[:, 0]),
            names=("node_pairs", "labels"),
        ),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.node_pairs is not None
        assert minibatch.labels is not None
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = total_ids % batch_size
        else:
            assert False
        # Gather (src, dst) and labels across all etypes in this batch.
        src = []
        dst = []
        label = []
        for _, node_pairs in minibatch.node_pairs.items():
            assert isinstance(node_pairs, tuple)
            src.append(node_pairs[0])
            dst.append(node_pairs[1])
        for _, v_label in minibatch.labels.items():
            label.append(v_label)
        src = torch.cat(src)
        dst = torch.cat(dst)
        label = torch.cat(label)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle
    assert torch.all(labels[:-1] <= labels[1:]).item() is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSetDict of node pairs with negative dsts."""
    # Head, tail and negative tails.
    num_ids = 103
    total_ids = 2 * num_ids
    num_negs = 2
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    neg_dsts_like = torch.arange(
        num_ids * 4, num_ids * 4 + num_ids * num_negs
    ).reshape(-1, num_negs)
    neg_dsts_follow = torch.arange(
        num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
    ).reshape(-1, num_negs)
    data_dict = {
        "user:like:item": gb.ItemSet(
            (node_pairs_like, neg_dsts_like),
            names=("node_pairs", "negative_dsts"),
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs_follow, neg_dsts_follow),
            names=("node_pairs", "negative_dsts"),
        ),
    }
    item_set = gb.ItemSetDict(data_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.node_pairs is not None
        assert minibatch.negative_dsts is not None
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            # The trailing partial batch is kept only when drop_last=False.
            expected_batch_size = total_ids % batch_size
        else:
            assert False
        # Gather pairs and negatives across all etypes in this batch.
        src = []
        dst = []
        negs = []
        for _, node_pairs in minibatch.node_pairs.items():
            assert isinstance(node_pairs, tuple)
            src.append(node_pairs[0])
            dst.append(node_pairs[1])
        for _, v_negs in minibatch.negative_dsts.items():
            negs.append(v_negs)
        src = torch.cat(src)
        dst = torch.cat(dst)
        negs = torch.cat(negs)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(negs) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs)
        assert negs.dim() == 2
        assert negs.shape[0] == expected_batch_size
        assert negs.shape[1] == num_negs
        assert torch.equal(src + 1, dst)
        assert torch.equal(negs[:, 0] + 1, negs[:, 1])
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    # `.item()` is required: a 0-dim Tensor is never identical to a Python
    # bool, so the bare `is not shuffle` comparison would pass vacuously.
    assert torch.all(src_ids[:-1] <= src_ids[1:]).item() is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]).item() is not shuffle
    assert torch.all(negs_ids[:-1] <= negs_ids[1:]).item() is not shuffle


def distributed_item_sampler_subprocess(
    proc_id,
    nprocs,
    item_set,
    num_ids,
    batch_size,
    shuffle,
    drop_last,
    drop_uneven_inputs,
):
    """Worker entry point: run a DistributedItemSampler on one rank and
    verify the per-rank item/batch counts and that no item repeats.

    Args:
        proc_id: Rank of this worker within the process group.
        nprocs: Total number of workers (world size).
        item_set: Item set shared by all ranks (passed to the sampler).
        num_ids: Total number of items in ``item_set``.
        batch_size: Number of items per minibatch.
        shuffle: Whether the sampler shuffles items.
        drop_last: Whether a trailing partial batch is dropped.
        drop_uneven_inputs: Whether ranks drop batches so every rank
            yields the same number of batches.
    """
    # On Windows, the init method can only be file.
    init_method = (
        f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
        if platform == "win32"
        else "tcp://127.0.0.1:12345"
    )
    dist.init_process_group(
        backend="gloo",  # Use Gloo backend for CPU multiprocessing
        init_method=init_method,
        world_size=nprocs,
        rank=proc_id,
    )

    # Create a DistributedItemSampler.
    item_sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        drop_uneven_inputs=drop_uneven_inputs,
    )
    # An empty feature store suffices here; only sampling behavior is tested.
    feature_fetcher = gb.FeatureFetcher(
        item_sampler,
        gb.BasicFeatureStore({}),
        [],
    )
    data_loader = gb.SingleProcessDataLoader(feature_fetcher)

    # Count the numbers of items and batches.
    num_items = 0
    sampled_count = torch.zeros(num_ids, dtype=torch.int32)
    for i in data_loader:
        # Count how many times each item is sampled.
        sampled_count[i.seed_nodes] += 1
        num_items += i.seed_nodes.size(0)
    num_batches = len(list(item_sampler))

    # Calculate expected numbers of items and batches.
    # Items are split as evenly as possible: the first (num_ids % nprocs)
    # ranks receive one extra item.
    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
    if drop_last and expected_num_items % batch_size > 0:
        expected_num_items -= expected_num_items % batch_size
    expected_num_batches = expected_num_items // batch_size + (
        (not drop_last) and (expected_num_items % batch_size > 0)
    )
    if drop_uneven_inputs:
        # NOTE(review): these two branches model which ranks are expected to
        # drop one batch (and its items) so all ranks agree on the batch
        # count — confirm against DistributedItemSampler's documented policy.
        if (
            (not drop_last)
            and (num_ids % (nprocs * batch_size) < nprocs)
            and (num_ids % (nprocs * batch_size) > proc_id)
        ):
            expected_num_batches -= 1
            expected_num_items -= 1
        elif (
            drop_last
            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
            and (num_ids % nprocs > proc_id)
        ):
            expected_num_batches -= 1
            expected_num_items -= batch_size
        num_batches_tensor = torch.tensor(num_batches)
        dist.broadcast(num_batches_tensor, 0)
        # Test if the number of batches are the same for all processes.
        assert num_batches_tensor == num_batches

    # Add up results from all processes.
    dist.reduce(sampled_count, 0)

    try:
        # Check if the numbers are as expected.
        assert num_items == expected_num_items
        assert num_batches == expected_num_batches

        # Make sure no item is sampled more than once.
        assert sampled_count.max() <= 1
    finally:
        # Always tear down the process group, even on assertion failure.
        dist.destroy_process_group()


@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
def test_DistributedItemSampler(
    num_ids, shuffle, drop_last, drop_uneven_inputs
):
    """Spawn worker processes and validate DistributedItemSampler per rank."""
    nprocs = 4
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")

    # A stale init file can make process-group setup hang on Windows, so
    # remove it up front when present.
    if platform == "win32":
        init_file = os.path.join(os.getcwd(), "dis_tempfile")
        try:
            os.remove(init_file)
        except FileNotFoundError:
            pass

    worker_args = (
        nprocs,
        item_set,
        num_ids,
        batch_size,
        shuffle,
        drop_last,
        drop_uneven_inputs,
    )
    mp.spawn(
        distributed_item_sampler_subprocess,
        args=worker_args,
        nprocs=nprocs,
        join=True,
    )