test_itemset.py 21.4 KB
Newer Older
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
7
8
from dgl import graphbolt as gb


def test_ItemSet_names():
    """Check ``ItemSet.names`` and validation of the name count.

    A single string name is stored as a one-element tuple, and a missing
    name yields ``None``. Passing more names than item components raises
    ``AssertionError``.
    """
    # ItemSet with single name.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert item_set.names == ("seed_nodes",)

    # ItemSet with multiple names.
    item_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_nodes", "labels"),
    )
    assert item_set.names == ("seed_nodes", "labels")

    # ItemSet without name.
    item_set = gb.ItemSet(torch.arange(0, 5))
    assert item_set.names is None

    # Integer-initialized ItemSet with excessive names.
    with pytest.raises(
        AssertionError,
        match=re.escape("Number of items (1) and names (2) must match."),
    ):
        _ = gb.ItemSet(5, names=("seed_nodes", "labels"))

    # ItemSet with mismatched items and names.
    with pytest.raises(
        AssertionError,
        match=re.escape("Number of items (1) and names (2) must match."),
    ):
        _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))


@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype):
    """Check an ItemSet built from a 0-d tensor holding the item count.

    Iteration yields ``0 .. n-1`` with the tensor's ``dtype``. Integer and
    slice indexing return the matching values.
    """
    item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes")
    for i, item in enumerate(item_set):
        assert i == item
        assert item.dtype == dtype
    assert item_set[2] == torch.tensor(2, dtype=dtype)
    assert torch.equal(
        item_set[slice(1, 4, 2)], torch.arange(1, 4, 2, dtype=dtype)
    )
def test_ItemSet_length():
    """Check ``len()``, iteration and indexing with and without a length.

    Items that only implement ``__iter__`` can still be iterated. For them,
    ``len()`` and indexing must raise ``TypeError``.
    """
    # Integer with valid length
    num = 10
    item_set = gb.ItemSet(num)
    assert len(item_set) == 10
    # Test __iter__() method. Same as below.
    for i, item in enumerate(item_set):
        assert i == item

    # Single iterable with valid length.
    ids = torch.arange(0, 5)
    item_set = gb.ItemSet(ids)
    assert len(item_set) == 5
    for i, item in enumerate(item_set):
        assert i == item.item()

    # Tuple of iterables with valid length.
    item_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(item_set) == 5
    for i, (item1, item2) in enumerate(item_set):
        assert i == item1.item()
        assert i + 5 == item2.item()

    class InvalidLength:
        # Iterable that deliberately defines neither __len__ nor __getitem__.
        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable with invalid length.
    item_set = gb.ItemSet(InvalidLength())
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't have valid length."
    ):
        _ = len(item_set)
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't support indexing."
    ):
        _ = item_set[0]
    for i, item in enumerate(item_set):
        assert i == item

    # Tuple of iterables with invalid length.
    item_set = gb.ItemSet((InvalidLength(), InvalidLength()))
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't have valid length."
    ):
        _ = len(item_set)
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't support indexing."
    ):
        _ = item_set[0]
    for i, (item1, item2) in enumerate(item_set):
        assert i == item1
        assert i == item2


def test_ItemSet_seed_nodes():
    """Check iteration and indexing of a node-ID ItemSet.

    Two sources are covered: a tensor and an integer count. An integer-built
    ItemSet rejects out-of-range indices with ``IndexError``. It rejects
    tensor indices with ``TypeError``.
    """
    # Node IDs with tensor.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert item_set.names == ("seed_nodes",)
    # Iterating over ItemSet and indexing one by one.
    for i, item in enumerate(item_set):
        assert i == item.item()
        assert i == item_set[i]
    # Indexing with a slice.
    assert torch.equal(item_set[:], torch.arange(0, 5))
    # Indexing with an Iterable.
    assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))

    # Node IDs with single integer.
    item_set = gb.ItemSet(5, names="seed_nodes")
    assert item_set.names == ("seed_nodes",)
    # Iterating over ItemSet and indexing one by one.
    for i, item in enumerate(item_set):
        assert i == item.item()
        assert i == item_set[i]
    # Indexing with a slice.
    assert torch.equal(item_set[:], torch.arange(0, 5))
    # Indexing with an integer.
    assert item_set[0] == 0
    assert item_set[-1] == 4
    # Indexing that is out of range.
    with pytest.raises(IndexError, match="ItemSet index out of range."):
        _ = item_set[5]
    with pytest.raises(IndexError, match="ItemSet index out of range."):
        _ = item_set[-10]
    # Indexing with tensor.
    with pytest.raises(
        TypeError, match="ItemSet indices must be integer or slice."
    ):
        _ = item_set[torch.arange(3)]


def test_ItemSet_seed_nodes_labels():
    """Check iteration and indexing of an ItemSet of (node ID, label) pairs."""
    # Node IDs and labels.
    seed_nodes = torch.arange(0, 5)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    assert item_set.names == ("seed_nodes", "labels")
    # Iterating over ItemSet and indexing one by one.
    for i, (seed_node, label) in enumerate(item_set):
        assert seed_node == seed_nodes[i]
        assert label == labels[i]
        assert seed_node == item_set[i][0]
        assert label == item_set[i][1]
    # Indexing with a slice.
    assert torch.equal(item_set[:][0], seed_nodes)
    assert torch.equal(item_set[:][1], labels)
    # Indexing with an Iterable.
    assert torch.equal(item_set[torch.arange(0, 5)][0], seed_nodes)
    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)


def test_ItemSet_node_pairs():
    """Check an ItemSet over a single (N, 2) node-pair tensor.

    Each item unpacks into ``(src, dst)``.
    """
    # Node pairs.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    assert item_set.names == ("node_pairs",)
    # Iterating over ItemSet and indexing one by one.
    for i, (src, dst) in enumerate(item_set):
        assert node_pairs[i][0] == src
        assert node_pairs[i][1] == dst
        assert node_pairs[i][0] == item_set[i][0]
        assert node_pairs[i][1] == item_set[i][1]
    # Indexing with a slice.
    assert torch.equal(item_set[:], node_pairs)
    # Indexing with an Iterable.
    assert torch.equal(item_set[torch.arange(0, 5)], node_pairs)


def test_ItemSet_node_pairs_labels():
    """Check an ItemSet pairing an (N, 2) node-pair tensor with N labels."""
    # Node pairs and labels
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    assert item_set.names == ("node_pairs", "labels")
    # Iterating over ItemSet and indexing one by one.
    for i, (node_pair, label) in enumerate(item_set):
        assert torch.equal(node_pairs[i], node_pair)
        assert labels[i] == label
        assert torch.equal(node_pairs[i], item_set[i][0])
        assert labels[i] == item_set[i][1]
    # Indexing with a slice.
    assert torch.equal(item_set[:][0], node_pairs)
    assert torch.equal(item_set[:][1], labels)
    # Indexing with an Iterable.
    assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)


def test_ItemSet_node_pairs_neg_dsts():
    """Check an ItemSet pairing node pairs with negative destinations.

    The node pairs are an (N, 2) tensor. The negative destinations are an
    (N, 3) tensor.
    """
    # Node pairs and negative destinations.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    assert item_set.names == ("node_pairs", "negative_dsts")
    # Iterating over ItemSet and indexing one by one.
    for i, (node_pair, neg_dst) in enumerate(item_set):
        assert torch.equal(node_pairs[i], node_pair)
        assert torch.equal(neg_dsts[i], neg_dst)
        assert torch.equal(node_pairs[i], item_set[i][0])
        assert torch.equal(neg_dsts[i], item_set[i][1])
    # Indexing with a slice.
    assert torch.equal(item_set[:][0], node_pairs)
    assert torch.equal(item_set[:][1], neg_dsts)
    # Indexing with an Iterable.
    assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
    assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts)


def test_ItemSet_graphs():
    """Check an unnamed ItemSet over a Python list of DGL graphs.

    Iteration, integer indexing and full slicing must return the graphs.
    """
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    assert item_set.names is None
    # Iterating over ItemSet and indexing one by one.
    for i, item in enumerate(item_set):
        assert graphs[i] == item
        assert graphs[i] == item_set[i]
    # Indexing with a slice.
    assert item_set[:] == graphs


def test_ItemSetDict_names():
    """Check that ``ItemSetDict.names`` reflects its member ItemSets' names.

    Construction raises ``AssertionError`` when the members' names differ.
    """
    # ItemSetDict with single name.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    assert item_set.names == ("seed_nodes",)

    # ItemSetDict with multiple names.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_nodes", "labels"),
            ),
        }
    )
    assert item_set.names == ("seed_nodes", "labels")

    # ItemSetDict with no name.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert item_set.names is None

    # ItemSetDict with mismatched items and names.
    with pytest.raises(
        AssertionError,
        match=re.escape("All itemsets must have the same names."),
    ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)),
                    names=("seed_nodes", "labels"),
                ),
                "item": gb.ItemSet(
                    (torch.arange(5, 10),), names=("seed_nodes",)
                ),
            }
        )


def test_ItemSetDict_length():
    """Check ``len()`` and indexing of ItemSetDicts with and without a length.

    ``len()`` is the sum of the member lengths. Members that only implement
    ``__iter__`` make ``len()`` and indexing raise ``TypeError``.
    """
    # Single iterable with valid length.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(user_ids),
            "item": gb.ItemSet(item_ids),
        }
    )
    assert len(item_set) == len(user_ids) + len(item_ids)

    # Tuple of iterables with valid length.
    node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts_like = torch.arange(10, 20).reshape(-1, 2)
    node_pairs_follow = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts_follow = torch.arange(10, 20).reshape(-1, 2)
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)),
            "user:follow:user": gb.ItemSet(
                (node_pairs_follow, neg_dsts_follow)
            ),
        }
    )
    assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0)

    class InvalidLength:
        # Iterable that deliberately defines neither __len__ nor __getitem__.
        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable with invalid length.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(InvalidLength()),
            "item": gb.ItemSet(InvalidLength()),
        }
    )
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't have valid length."
    ):
        _ = len(item_set)
    with pytest.raises(
        TypeError, match="ItemSetDict instance doesn't support indexing."
    ):
        _ = item_set[0]

    # Tuple of iterables with invalid length.
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
            "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
        }
    )
    with pytest.raises(
        TypeError, match="ItemSet instance doesn't have valid length."
    ):
        _ = len(item_set)
    with pytest.raises(
        TypeError, match="ItemSetDict instance doesn't support indexing."
    ):
        _ = item_set[0]


def test_ItemSetDict_iteration_seed_nodes():
    """Check iteration, indexing and slicing of a node-ID ItemSetDict.

    Items are yielded as one-entry dicts in member order. Slices may span
    members. Invalid slices or indices raise ``AssertionError``,
    ``IndexError`` or ``TypeError``.
    """
    # Node IDs.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(5, 10)
    ids = {
        "user": gb.ItemSet(user_ids, names="seed_nodes"),
        "item": gb.ItemSet(item_ids, names="seed_nodes"),
    }
    # Expected (key, item) sequence: members concatenated in dict order.
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    assert item_set.names == ("seed_nodes",)
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert chained_ids[i][0] in item
        assert item[chained_ids[i][0]] == chained_ids[i][1]
        assert item_set[i] == item
        assert item_set[i - len(item_set)] == item
    # Indexing all with a slice.
    assert torch.equal(item_set[:]["user"], user_ids)
    assert torch.equal(item_set[:]["item"], item_ids)
    # Indexing partial with a slice.
    partial_data = item_set[:3]
    assert len(partial_data) == 1
    assert torch.equal(partial_data["user"], user_ids[:3])
    partial_data = item_set[7:]
    assert len(partial_data) == 1
    assert torch.equal(partial_data["item"], item_ids[2:])
    partial_data = item_set[3:7]
    assert len(partial_data) == 2
    assert torch.equal(partial_data["user"], user_ids[3:5])
    assert torch.equal(partial_data["item"], item_ids[:2])

    # Exception cases.
    with pytest.raises(AssertionError, match="Step must be 1."):
        _ = item_set[::2]
    with pytest.raises(
        AssertionError, match="Start must be smaller than stop."
    ):
        _ = item_set[5:3]
    with pytest.raises(
        AssertionError, match="Start must be smaller than stop."
    ):
        _ = item_set[-1:3]
    with pytest.raises(IndexError, match="ItemSetDict index out of range."):
        _ = item_set[20]
    with pytest.raises(IndexError, match="ItemSetDict index out of range."):
        _ = item_set[-20]
    with pytest.raises(
        TypeError, match="ItemSetDict indices must be int or slice."
    ):
        _ = item_set[torch.arange(3)]


def test_ItemSetDict_iteration_seed_nodes_labels():
    """Check iteration, indexing and slicing of an ItemSetDict of pairs.

    Each member holds (node ID, label) pairs.
    """
    # Node IDs and labels.
    user_ids = torch.arange(0, 5)
    user_labels = torch.randint(0, 3, (5,))
    item_ids = torch.arange(5, 10)
    item_labels = torch.randint(0, 3, (5,))
    ids_labels = {
        "user": gb.ItemSet(
            (user_ids, user_labels), names=("seed_nodes", "labels")
        ),
        "item": gb.ItemSet(
            (item_ids, item_labels), names=("seed_nodes", "labels")
        ),
    }
    # Expected (key, item) sequence: members concatenated in dict order.
    chained_ids = []
    for key, value in ids_labels.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids_labels)
    assert item_set.names == ("seed_nodes", "labels")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert chained_ids[i][0] in item
        assert item[chained_ids[i][0]] == chained_ids[i][1]
        assert item_set[i] == item
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user"][0], user_ids)
    assert torch.equal(item_set[:]["user"][1], user_labels)
    assert torch.equal(item_set[:]["item"][0], item_ids)
    assert torch.equal(item_set[:]["item"][1], item_labels)


def test_ItemSetDict_iteration_node_pairs():
    """Check iteration, indexing and slicing of a node-pair ItemSetDict.

    The ItemSetDict is keyed by edge type.
    """
    # Node pairs.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
        "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
    }
    # Expected (key, item) sequence: members concatenated in dict order.
    expected_data = []
    for key, value in node_pairs_dict.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_dict)
    assert item_set.names == ("node_pairs",)
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert expected_data[i][0] in item
        assert torch.equal(item[expected_data[i][0]], expected_data[i][1])
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key], item[key])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"], node_pairs)


def test_ItemSetDict_iteration_node_pairs_labels():
    """Check an edge-type-keyed ItemSetDict of (node pair, label) items."""
    # Node pairs and labels
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    node_pairs_labels = {
        "user:like:item": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
    }
    # Expected (key, item) sequence: members concatenated in dict order.
    expected_data = []
    for key, value in node_pairs_labels.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_labels)
    assert item_set.names == ("node_pairs", "labels")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, value = expected_data[i]
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert item[key][1] == value[1]
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key][0], item[key][0])
        assert torch.equal(item_set[i][key][1], item[key][1])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
    assert torch.equal(item_set[:]["user:like:item"][1], labels)
    assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"][1], labels)


def test_ItemSetDict_iteration_node_pairs_neg_dsts():
    """Check an edge-type-keyed ItemSetDict of (node pair, negative dsts)."""
    # Node pairs and negative destinations.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    node_pairs_neg_dsts = {
        "user:like:item": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
    }
    # Expected (key, item) sequence: members concatenated in dict order.
    expected_data = []
    for key, value in node_pairs_neg_dsts.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_neg_dsts)
    assert item_set.names == ("node_pairs", "negative_dsts")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, value = expected_data[i]
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert torch.equal(item[key][1], value[1])
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key][0], item[key][0])
        assert torch.equal(item_set[i][key][1], item[key][1])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
    assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts)
    assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts)


def test_ItemSet_repr():
    """Check the exact multi-line ``str()`` form of an ItemSet.

    The output shows the ``items`` tuple and the ``names`` tuple.
    """
    # ItemSet with single name.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    expected_str = (
        "ItemSet(\n"
        "    items=(tensor([0, 1, 2, 3, 4]),),\n"
        "    names=('seed_nodes',),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set

    # ItemSet with multiple names.
    item_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_nodes", "labels"),
    )
    expected_str = (
        "ItemSet(\n"
        "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
        "    names=('seed_nodes', 'labels'),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set


def test_ItemSetDict_repr():
    """Check the exact multi-line ``str()`` form of an ItemSetDict.

    Member ItemSets are rendered nested and indented inside ``itemsets``.
    The shared ``names`` tuple follows them.
    """
    # ItemSetDict with single name.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    expected_str = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]),),\n"
        "                 names=('seed_nodes',),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]),),\n"
        "                 names=('seed_nodes',),\n"
        "             )},\n"
        "    names=('seed_nodes',),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set

    # ItemSetDict with multiple names.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_nodes", "labels"),
            ),
        }
    )
    expected_str = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
        "                 names=('seed_nodes', 'labels'),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
        "                 names=('seed_nodes', 'labels'),\n"
        "             )},\n"
        "    names=('seed_nodes', 'labels'),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set