# test_itemset.py
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
from dgl import graphbolt as gb
Rhett Ying's avatar
Rhett Ying committed
7
from torch.testing import assert_close
8
9


def test_ItemSet_names():
    """Check how ``ItemSet`` normalizes and validates its ``names``."""
    # A single string name is exposed as a one-element tuple.
    single = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert single.names == ("seed_nodes",)

    # One name per item when the ItemSet wraps a tuple of tensors.
    pair = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_nodes", "labels"),
    )
    assert pair.names == ("seed_nodes", "labels")

    # Omitting ``names`` leaves the attribute as ``None``.
    unnamed = gb.ItemSet(torch.arange(0, 5))
    assert unnamed.names is None

    # Supplying two names for a single item must be rejected, both for an
    # integer-initialized ItemSet and for a single tensor.
    mismatch_msg = re.escape("Number of items (1) and names (2) must match.")
    for items in (5, torch.arange(0, 5)):
        with pytest.raises(AssertionError, match=mismatch_msg):
            _ = gb.ItemSet(items, names=("seed_nodes", "labels"))


def test_ItemSet_length():
    """Check ``len``, iteration and indexing for sized and unsized items."""
    # An integer ``n`` is iterated as 0, 1, ..., n - 1.
    int_set = gb.ItemSet(10)
    assert len(int_set) == 10
    # Iteration order must match positions (same pattern below).
    for expected, got in enumerate(int_set):
        assert expected == got

    # A single tensor: one element per entry.
    tensor_set = gb.ItemSet(torch.arange(0, 5))
    assert len(tensor_set) == 5
    for expected, got in enumerate(tensor_set):
        assert expected == got.item()

    # A tuple of tensors is iterated in lock-step.
    zipped_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(zipped_set) == 5
    for expected, (first, second) in enumerate(zipped_set):
        assert expected == first.item()
        assert expected + 5 == second.item()

    class InvalidLength:
        # Iterable that defines ``__iter__`` only (no ``__len__``).
        def __iter__(self):
            return iter([0, 1, 2])

    def _assert_unsized(item_set):
        # An ItemSet over an unsized iterable rejects both len and indexing.
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't have valid length."
        ):
            _ = len(item_set)
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't support indexing."
        ):
            _ = item_set[0]

    # Single unsized iterable: still iterable despite the errors above.
    unsized_single = gb.ItemSet(InvalidLength())
    _assert_unsized(unsized_single)
    for expected, got in enumerate(unsized_single):
        assert expected == got

    # Tuple of unsized iterables.
    unsized_tuple = gb.ItemSet((InvalidLength(), InvalidLength()))
    _assert_unsized(unsized_tuple)
    for expected, (first, second) in enumerate(unsized_tuple):
        assert expected == first
        assert expected == second


def test_ItemSet_seed_nodes():
    """Check iteration and indexing of a named node-ID ItemSet."""
    expected = torch.arange(0, 5)

    # Tensor-backed node IDs.
    tensor_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert tensor_set.names == ("seed_nodes",)
    # Walk the set and cross-check each element against positional indexing.
    for idx, node in enumerate(tensor_set):
        assert idx == node.item()
        assert idx == tensor_set[idx]
    # Full slice and tensor-of-indices access both return every ID.
    assert torch.equal(tensor_set[:], expected)
    assert torch.equal(tensor_set[torch.arange(0, 5)], expected)

    # Integer-backed node IDs.
    range_set = gb.ItemSet(5, names="seed_nodes")
    assert range_set.names == ("seed_nodes",)
    for idx, node in enumerate(range_set):
        assert idx == node.item()
        assert idx == range_set[idx]
    assert torch.equal(range_set[:], expected)
    # Scalar indexing, including a negative index.
    assert range_set[0] == 0
    assert range_set[-1] == 4
    # Out-of-range positive and negative indices raise IndexError.
    for bad_index in (5, -10):
        with pytest.raises(IndexError, match="ItemSet index out of range."):
            _ = range_set[bad_index]
    # Tensor indices are rejected for an integer-backed set.
    with pytest.raises(
        TypeError, match="ItemSet indices must be integer or slice."
    ):
        _ = range_set[torch.arange(3)]


def test_ItemSet_seed_nodes_labels():
    """Check a two-field ItemSet pairing node IDs with labels."""
    seed_nodes = torch.arange(0, 5)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    assert item_set.names == ("seed_nodes", "labels")

    # Each element is a (seed_node, label) pair that must agree with both the
    # source tensors and positional indexing.
    for pos, (node, lbl) in enumerate(item_set):
        indexed = item_set[pos]
        assert node == seed_nodes[pos]
        assert lbl == labels[pos]
        assert node == indexed[0]
        assert lbl == indexed[1]

    # Slice and tensor-of-indices access return one full tensor per field.
    for batch in (item_set[:], item_set[torch.arange(0, 5)]):
        assert torch.equal(batch[0], seed_nodes)
        assert torch.equal(batch[1], labels)


def test_ItemSet_node_pairs():
    """Check an ItemSet over a two-column tensor of node pairs."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    assert item_set.names == ("node_pairs",)

    # Rows unpack into (src, dst) and match both the raw tensor and indexing.
    for row, (src, dst) in enumerate(item_set):
        expected_src, expected_dst = node_pairs[row][0], node_pairs[row][1]
        indexed = item_set[row]
        assert expected_src == src
        assert expected_dst == dst
        assert expected_src == indexed[0]
        assert expected_dst == indexed[1]

    # Slice and tensor-of-indices access return the full pair matrix.
    assert torch.equal(item_set[:], node_pairs)
    assert torch.equal(item_set[torch.arange(0, 5)], node_pairs)


def test_ItemSet_node_pairs_labels():
    """Check an ItemSet pairing node-pair rows with labels."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    assert item_set.names == ("node_pairs", "labels")

    # Row ``k`` of each tensor is zipped into the ``k``-th element.
    for k, (pair, label) in enumerate(item_set):
        indexed = item_set[k]
        assert torch.equal(node_pairs[k], pair)
        assert labels[k] == label
        assert torch.equal(node_pairs[k], indexed[0])
        assert labels[k] == indexed[1]

    # Slice and tensor-of-indices access return both full tensors.
    for batch in (item_set[:], item_set[torch.arange(0, 5)]):
        assert torch.equal(batch[0], node_pairs)
        assert torch.equal(batch[1], labels)


def test_ItemSet_node_pairs_neg_dsts():
    """Check an ItemSet of node pairs with per-pair negative destinations."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    assert item_set.names == ("node_pairs", "negative_dsts")

    # Row ``k`` of each tensor is zipped into the ``k``-th element.
    for k, (pair, negatives) in enumerate(item_set):
        indexed = item_set[k]
        assert torch.equal(node_pairs[k], pair)
        assert torch.equal(neg_dsts[k], negatives)
        assert torch.equal(node_pairs[k], indexed[0])
        assert torch.equal(neg_dsts[k], indexed[1])

    # Slice and tensor-of-indices access return both full tensors.
    for batch in (item_set[:], item_set[torch.arange(0, 5)]):
        assert torch.equal(batch[0], node_pairs)
        assert torch.equal(batch[1], neg_dsts)


def test_ItemSet_graphs():
    """Check an unnamed ItemSet wrapping a list of DGL graphs."""
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    assert item_set.names is None

    # Every position must agree across the source list, iteration and
    # positional indexing.
    for pos, graph in enumerate(item_set):
        source = graphs[pos]
        assert source == graph
        assert source == item_set[pos]
    # A full slice reproduces the source list.
    assert item_set[:] == graphs


def test_ItemSetDict_names():
    """Check that ItemSetDict exposes the names shared by its members."""
    # Every member has the single name "seed_nodes".
    single_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    assert single_named.names == ("seed_nodes",)

    # Every member carries the same two names.
    two_fields = ("seed_nodes", "labels")
    multi_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)), names=two_fields
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)), names=two_fields
            ),
        }
    )
    assert multi_named.names == two_fields

    # No member is named, so the dict reports no names either.
    unnamed = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert unnamed.names is None

    # Members whose names disagree are rejected at construction time.
    with pytest.raises(
        AssertionError,
        match=re.escape("All itemsets must have the same names."),
    ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)),
                    names=two_fields,
                ),
                "item": gb.ItemSet(
                    (torch.arange(5, 10),), names=("seed_nodes",)
                ),
            }
        )


def test_ItemSetDict_length():
    """Check that ItemSetDict length is the sum of its members' lengths."""
    # One tensor per key.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    node_dict = gb.ItemSetDict(
        {"user": gb.ItemSet(user_ids), "item": gb.ItemSet(item_ids)}
    )
    assert len(node_dict) == len(user_ids) + len(item_ids)

    # A tuple of tensors per key; each member contributes its row count.
    like_pairs = torch.arange(0, 10).reshape(-1, 2)
    like_negs = torch.arange(10, 20).reshape(-1, 2)
    follow_pairs = torch.arange(0, 10).reshape(-1, 2)
    follow_negs = torch.arange(10, 20).reshape(-1, 2)
    edge_dict = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((like_pairs, like_negs)),
            "user:follow:user": gb.ItemSet((follow_pairs, follow_negs)),
        }
    )
    assert len(edge_dict) == like_pairs.size(0) + follow_pairs.size(0)

    class InvalidLength:
        # Iterable that defines ``__iter__`` only (no ``__len__``).
        def __iter__(self):
            return iter([0, 1, 2])

    def _assert_unsized(item_set):
        # ``len`` is expected to fail with the ItemSet message, while
        # indexing is expected to fail with the ItemSetDict message.
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't have valid length."
        ):
            _ = len(item_set)
        with pytest.raises(
            TypeError, match="ItemSetDict instance doesn't support indexing."
        ):
            _ = item_set[0]

    # One unsized iterable per key.
    _assert_unsized(
        gb.ItemSetDict(
            {
                "user": gb.ItemSet(InvalidLength()),
                "item": gb.ItemSet(InvalidLength()),
            }
        )
    )
    # A tuple of unsized iterables per key.
    _assert_unsized(
        gb.ItemSetDict(
            {
                "user:like:item": gb.ItemSet(
                    (InvalidLength(), InvalidLength())
                ),
                "user:follow:user": gb.ItemSet(
                    (InvalidLength(), InvalidLength())
                ),
            }
        )
    )


def test_ItemSetDict_iteration_seed_nodes():
    """Check iteration, int indexing and slicing of a node-ID ItemSetDict."""
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(5, 10)
    ids = {
        "user": gb.ItemSet(user_ids, names="seed_nodes"),
        "item": gb.ItemSet(item_ids, names="seed_nodes"),
    }
    # Expected (key, node) sequence: all "user" entries, then all "item" ones.
    chained_ids = [
        (key, value) for key, member in ids.items() for value in member
    ]
    item_set = gb.ItemSetDict(ids)
    assert item_set.names == ("seed_nodes",)
    total = len(item_set)

    # Each yielded element is a one-entry dict {key: node}; it must also be
    # reachable by its positive and its negative position.
    for pos, item in enumerate(item_set):
        key, value = chained_ids[pos]
        assert len(item) == 1
        assert isinstance(item, dict)
        assert key in item
        assert item[key] == value
        assert item_set[pos] == item
        assert item_set[pos - total] == item

    # A full slice returns each key's complete tensor.
    everything = item_set[:]
    assert torch.equal(everything["user"], user_ids)
    assert torch.equal(everything["item"], item_ids)
    # Partial slices keep only the keys they overlap (positions 0-4 belong to
    # "user", 5-9 to "item").
    head = item_set[:3]
    assert len(list(head.keys())) == 1
    assert torch.equal(head["user"], user_ids[:3])
    tail = item_set[7:]
    assert len(list(tail.keys())) == 1
    assert torch.equal(tail["item"], item_ids[2:])
    middle = item_set[3:7]
    assert len(list(middle.keys())) == 2
    assert torch.equal(middle["user"], user_ids[3:5])
    assert torch.equal(middle["item"], item_ids[:2])

    # Unsupported slices.
    with pytest.raises(AssertionError, match="Step must be 1."):
        _ = item_set[::2]
    for bad_slice in (slice(5, 3), slice(-1, 3)):
        with pytest.raises(
            AssertionError, match="Start must be smaller than stop."
        ):
            _ = item_set[bad_slice]
    # Out-of-range integer indices.
    for bad_index in (20, -20):
        with pytest.raises(
            IndexError, match="ItemSetDict index out of range."
        ):
            _ = item_set[bad_index]
    # Tensor indices are rejected.
    with pytest.raises(
        TypeError, match="ItemSetDict indices must be int or slice."
    ):
        _ = item_set[torch.arange(3)]


def test_ItemSetDict_iteration_seed_nodes_labels():
    """Check iteration and indexing of a (seed_nodes, labels) ItemSetDict.

    Each yielded element is a one-entry dict whose key is "user" or "item"
    and whose value is a ``(seed_node, label)`` pair. The pair components are
    compared one at a time with ``torch.equal``, as the other tuple-valued
    ItemSetDict tests do. Using ``==`` on whole tuples or dicts instead would
    rely on an implicit ``bool()`` of element-wise comparison tensors, which
    only works while every component is a single-element tensor.
    """
    user_ids = torch.arange(0, 5)
    user_labels = torch.randint(0, 3, (5,))
    item_ids = torch.arange(5, 10)
    item_labels = torch.randint(0, 3, (5,))
    ids_labels = {
        "user": gb.ItemSet(
            (user_ids, user_labels), names=("seed_nodes", "labels")
        ),
        "item": gb.ItemSet(
            (item_ids, item_labels), names=("seed_nodes", "labels")
        ),
    }
    # Expected (key, (seed_node, label)) sequence, in dict insertion order.
    chained_ids = []
    for key, value in ids_labels.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids_labels)
    assert item_set.names == ("seed_nodes", "labels")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, expected = chained_ids[i]
        assert key in item
        assert torch.equal(item[key][0], expected[0])
        assert torch.equal(item[key][1], expected[1])
        # Positional indexing must return the same single-key element.
        indexed = item_set[i]
        assert indexed.keys() == item.keys()
        assert torch.equal(indexed[key][0], item[key][0])
        assert torch.equal(indexed[key][1], item[key][1])
    # Indexing with a slice returns (seed_nodes, labels) tensors per key.
    assert torch.equal(item_set[:]["user"][0], user_ids)
    assert torch.equal(item_set[:]["user"][1], user_labels)
    assert torch.equal(item_set[:]["item"][0], item_ids)
    assert torch.equal(item_set[:]["item"][1], item_labels)


def test_ItemSetDict_iteration_node_pairs():
    """Check an ItemSetDict of node pairs keyed by edge type."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    edge_types = ("user:like:item", "user:follow:user")
    node_pairs_dict = {
        etype: gb.ItemSet(node_pairs, names="node_pairs")
        for etype in edge_types
    }
    # Expected (edge_type, pair) sequence in insertion order.
    expected_data = [
        (etype, pair)
        for etype, member in node_pairs_dict.items()
        for pair in member
    ]
    item_set = gb.ItemSetDict(node_pairs_dict)
    assert item_set.names == ("node_pairs",)

    # Each element is a one-entry dict that positional indexing reproduces.
    for pos, item in enumerate(item_set):
        etype, pair = expected_data[pos]
        assert len(item) == 1
        assert isinstance(item, dict)
        assert etype in item
        assert torch.equal(item[etype], pair)
        indexed = item_set[pos]
        assert indexed.keys() == item.keys()
        only_key = next(iter(item))
        assert torch.equal(indexed[only_key], item[only_key])

    # A full slice returns the whole pair tensor for each edge type.
    everything = item_set[:]
    for etype in edge_types:
        assert torch.equal(everything[etype], node_pairs)


def test_ItemSetDict_iteration_node_pairs_labels():
    """Check an ItemSetDict of (node_pairs, labels) keyed by edge type."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    edge_types = ("user:like:item", "user:follow:user")
    node_pairs_labels = {
        etype: gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        )
        for etype in edge_types
    }
    # Expected (edge_type, (pair, label)) sequence in insertion order.
    expected_data = [
        (etype, entry)
        for etype, member in node_pairs_labels.items()
        for entry in member
    ]
    item_set = gb.ItemSetDict(node_pairs_labels)
    assert item_set.names == ("node_pairs", "labels")

    # Each element is a one-entry dict that positional indexing reproduces.
    for pos, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        etype, expected = expected_data[pos]
        assert etype in item
        assert torch.equal(item[etype][0], expected[0])
        assert item[etype][1] == expected[1]
        indexed = item_set[pos]
        assert indexed.keys() == item.keys()
        only_key = next(iter(item))
        assert torch.equal(indexed[only_key][0], item[only_key][0])
        assert torch.equal(indexed[only_key][1], item[only_key][1])

    # A full slice returns both full tensors for each edge type.
    everything = item_set[:]
    for etype in edge_types:
        assert torch.equal(everything[etype][0], node_pairs)
        assert torch.equal(everything[etype][1], labels)


def test_ItemSetDict_iteration_node_pairs_neg_dsts():
    """Check an ItemSetDict of (node_pairs, negative_dsts) per edge type."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    edge_types = ("user:like:item", "user:follow:user")
    node_pairs_neg_dsts = {
        etype: gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        )
        for etype in edge_types
    }
    # Expected (edge_type, (pair, negatives)) sequence in insertion order.
    expected_data = [
        (etype, entry)
        for etype, member in node_pairs_neg_dsts.items()
        for entry in member
    ]
    item_set = gb.ItemSetDict(node_pairs_neg_dsts)
    assert item_set.names == ("node_pairs", "negative_dsts")

    # Each element is a one-entry dict that positional indexing reproduces.
    for pos, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        etype, expected = expected_data[pos]
        assert etype in item
        assert torch.equal(item[etype][0], expected[0])
        assert torch.equal(item[etype][1], expected[1])
        indexed = item_set[pos]
        assert indexed.keys() == item.keys()
        only_key = next(iter(item))
        assert torch.equal(indexed[only_key][0], item[only_key][0])
        assert torch.equal(indexed[only_key][1], item[only_key][1])

    # A full slice returns both full tensors for each edge type.
    everything = item_set[:]
    for etype in edge_types:
        assert torch.equal(everything[etype][0], node_pairs)
        assert torch.equal(everything[etype][1], neg_dsts)


def test_ItemSet_repr():
    """Check the multi-line ``str`` rendering of named ItemSets."""
    cases = [
        # Single name: ``items`` and ``names`` render as 1-tuples.
        (
            gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "ItemSet(\n"
            "    items=(tensor([0, 1, 2, 3, 4]),),\n"
            "    names=('seed_nodes',),\n"
            ")",
        ),
        # Multiple names.
        (
            gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "ItemSet(\n"
            "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
            "    names=('seed_nodes', 'labels'),\n"
            ")",
        ),
    ]
    for item_set, expected_str in cases:
        assert str(item_set) == expected_str, item_set


def test_ItemSetDict_repr():
    """Check the nested multi-line ``str`` rendering of ItemSetDict."""
    # Every member has a single name; member reprs are nested and indented
    # under ``itemsets=``.
    single_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    expected_single = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]),),\n"
        "                 names=('seed_nodes',),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]),),\n"
        "                 names=('seed_nodes',),\n"
        "             )},\n"
        "    names=('seed_nodes',),\n"
        ")"
    )
    assert str(single_named) == expected_single, single_named

    # Every member carries the same two names.
    two_fields = ("seed_nodes", "labels")
    multi_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)), names=two_fields
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)), names=two_fields
            ),
        }
    )
    expected_multi = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
        "                 names=('seed_nodes', 'labels'),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
        "                 names=('seed_nodes', 'labels'),\n"
        "             )},\n"
        "    names=('seed_nodes', 'labels'),\n"
        ")"
    )
    assert str(multi_named) == expected_multi, multi_named