test_itemset.py 21.6 KB
Newer Older
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
7
8
from dgl import graphbolt as gb


9
10
def test_ItemSet_names():
    """ItemSet exposes its names as a tuple, or None when unnamed."""
    # A single string name is normalized into a 1-tuple.
    single = gb.ItemSet(torch.arange(0, 5), names="seeds")
    assert single.names == ("seeds",)

    # Several parallel tensors, one name each.
    paired = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seeds", "labels"),
    )
    assert paired.names == ("seeds", "labels")

    # No names supplied at all.
    unnamed = gb.ItemSet(torch.arange(0, 5))
    assert unnamed.names is None

    # An integer-initialized ItemSet and a single-tensor ItemSet both hold
    # one item, so supplying two names must fail with the same message.
    mismatch_msg = re.escape("Number of items (1) and names (2) must match.")
    for one_item in (5, torch.arange(0, 5)):
        with pytest.raises(AssertionError, match=mismatch_msg):
            _ = gb.ItemSet(one_item, names=("seeds", "labels"))


@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype):
    """A 0-dim tensor count yields items that keep the requested dtype."""
    scalar_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seeds")
    # Iteration yields 0, 1, ... in order, each carrying ``dtype``.
    for expected, value in enumerate(scalar_set):
        assert expected == value
        assert value.dtype == dtype
    # Integer and strided-slice indexing match torch's own results.
    assert scalar_set[2] == torch.tensor(2, dtype=dtype)
    strided = scalar_set[slice(1, 4, 2)]
    assert torch.equal(strided, torch.arange(1, 4, 2, dtype=dtype))


def test_ItemSet_length():
    """len(), indexing and iteration for sized and unsized ItemSets."""
    # An integer ItemSet is sized and iterates over 0..num-1.
    count_set = gb.ItemSet(10)
    assert len(count_set) == 10
    for expected, value in enumerate(count_set):
        assert expected == value

    # A single tensor.
    single_set = gb.ItemSet(torch.arange(0, 5))
    assert len(single_set) == 5
    for expected, value in enumerate(single_set):
        assert expected == value.item()

    # A tuple of equally long tensors iterates element-wise in lockstep.
    pair_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(pair_set) == 5
    for expected, (first, second) in enumerate(pair_set):
        assert expected == first.item()
        assert expected + 5 == second.item()

    class InvalidLength:
        """Iterable that defines neither ``__len__`` nor ``__getitem__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    def expect_unsized(unsized):
        # Both len() and indexing must be refused with a TypeError.
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't have valid length."
        ):
            _ = len(unsized)
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't support indexing."
        ):
            _ = unsized[0]

    # Unsized single iterable: no len()/indexing, but iteration still works.
    unsized_single = gb.ItemSet(InvalidLength())
    expect_unsized(unsized_single)
    for expected, value in enumerate(unsized_single):
        assert expected == value

    # Unsized tuple of iterables: same contract, yielding tuples.
    unsized_pair = gb.ItemSet((InvalidLength(), InvalidLength()))
    expect_unsized(unsized_pair)
    for expected, (first, second) in enumerate(unsized_pair):
        assert expected == first
        assert expected == second


def test_ItemSet_seed_nodes():
    """Seed-node ItemSets built from a tensor or from an integer count."""
    # Tensor-backed: supports integer, slice and tensor indexing.
    tensor_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
    assert tensor_set.names == ("seeds",)
    for pos, node in enumerate(tensor_set):
        assert pos == node.item()
        assert pos == tensor_set[pos]
    assert torch.equal(tensor_set[:], torch.arange(0, 5))
    assert torch.equal(tensor_set[torch.arange(0, 5)], torch.arange(0, 5))

    # Integer-backed: supports integer (incl. negative) and slice indexing.
    count_set = gb.ItemSet(5, names="seeds")
    assert count_set.names == ("seeds",)
    for pos, node in enumerate(count_set):
        assert pos == node.item()
        assert pos == count_set[pos]
    assert torch.equal(count_set[:], torch.arange(0, 5))
    assert count_set[0] == 0
    assert count_set[-1] == 4
    # Out-of-range positive and negative indices are rejected.
    for bad_index in (5, -10):
        with pytest.raises(IndexError, match="ItemSet index out of range."):
            _ = count_set[bad_index]
    # Tensor indices are not accepted by an integer-backed ItemSet.
    with pytest.raises(
        TypeError, match="ItemSet indices must be integer or slice."
    ):
        _ = count_set[torch.arange(3)]


def test_ItemSet_seed_nodes_labels():
    """An ItemSet pairing seed nodes with one label each."""
    seeds = torch.arange(0, 5)
    targets = torch.randint(0, 3, (5,))
    pairs = gb.ItemSet((seeds, targets), names=("seeds", "labels"))
    assert pairs.names == ("seeds", "labels")
    # Iteration and integer indexing both yield (seed, label) tuples.
    for idx, (seed, target) in enumerate(pairs):
        assert seed == seeds[idx]
        assert target == targets[idx]
        indexed = pairs[idx]
        assert seed == indexed[0]
        assert target == indexed[1]
    # Slice and tensor selectors both return the full columns.
    for selector in (slice(None), torch.arange(0, 5)):
        selected = pairs[selector]
        assert torch.equal(selected[0], seeds)
        assert torch.equal(selected[1], targets)


def test_ItemSet_node_pairs():
    """An ItemSet over a (N, 2) tensor of node pairs."""
    pairs = torch.arange(0, 10).reshape(-1, 2)
    pair_set = gb.ItemSet(pairs, names="seeds")
    assert pair_set.names == ("seeds",)
    # Each iteration step yields one row, unpackable as (src, dst).
    for row, (src, dst) in enumerate(pair_set):
        assert pairs[row][0] == src
        assert pairs[row][1] == dst
        fetched = pair_set[row]
        assert pairs[row][0] == fetched[0]
        assert pairs[row][1] == fetched[1]
    # Slice and tensor selectors both return the whole pair tensor.
    assert torch.equal(pair_set[:], pairs)
    assert torch.equal(pair_set[torch.arange(0, 5)], pairs)


def test_ItemSet_node_pairs_labels():
    """An ItemSet pairing (N, 2) node pairs with one label per pair."""
    pairs = torch.arange(0, 10).reshape(-1, 2)
    pair_labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((pairs, pair_labels), names=("seeds", "labels"))
    assert item_set.names == ("seeds", "labels")
    for row, (pair, label) in enumerate(item_set):
        assert torch.equal(pairs[row], pair)
        assert pair_labels[row] == label
        fetched = item_set[row]
        assert torch.equal(pairs[row], fetched[0])
        assert pair_labels[row] == fetched[1]
    # Slice and tensor selectors both return the full columns.
    for selector in (slice(None), torch.arange(0, 5)):
        selected = item_set[selector]
        assert torch.equal(selected[0], pairs)
        assert torch.equal(selected[1], pair_labels)


def test_ItemSet_node_pairs_labels_indexes():
    """An ItemSet holding node pairs with a label and an index per pair."""
    # Node pairs, labels and indexes: three parallel columns of length 5.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.tensor([1, 1, 0, 0, 0])
    indexes = torch.tensor([0, 1, 0, 0, 1])
    columns = (node_pairs, labels, indexes)
    item_set = gb.ItemSet(columns, names=("seeds", "labels", "indexes"))
    assert item_set.names == ("seeds", "labels", "indexes")
    # Iterating over ItemSet and indexing one by one: every step yields one
    # entry per column, in column order.
    for i, (node_pair, label, index) in enumerate(item_set):
        assert torch.equal(node_pairs[i], node_pair)
        assert torch.equal(labels[i], label)
        assert torch.equal(indexes[i], index)
        indexed = item_set[i]
        for col, column in enumerate(columns):
            assert torch.equal(column[i], indexed[col])
    # Indexing with a slice and with an Iterable both return full columns.
    for selector in (slice(None), torch.arange(0, 5)):
        selected = item_set[selector]
        for col, column in enumerate(columns):
            assert torch.equal(selected[col], column)


def test_ItemSet_graphs():
    """An unnamed ItemSet over a list of DGL graphs."""
    graph_list = [dgl.rand_graph(10, 20) for _ in range(5)]
    graph_set = gb.ItemSet(graph_list)
    assert graph_set.names is None
    # Iteration and integer indexing hand back the stored graph objects.
    for pos, graph in enumerate(graph_set):
        assert graph_list[pos] == graph
        assert graph_list[pos] == graph_set[pos]
    # A full slice compares equal to the original list.
    assert graph_set[:] == graph_list


def test_ItemSetDict_names():
    """ItemSetDict.names mirrors the names shared by its member ItemSets."""
    # Every member is named "seeds".
    seeds_only = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
        }
    )
    assert seeds_only.names == ("seeds",)

    # Every member carries ("seeds", "labels").
    both = ("seeds", "labels")
    seeds_and_labels = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)), names=both
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)), names=both
            ),
        }
    )
    assert seeds_and_labels.names == both

    # No member has names.
    unnamed = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert unnamed.names is None

    # Members whose names differ are rejected.
    with pytest.raises(
        AssertionError,
        match=re.escape("All itemsets must have the same names."),
    ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)), names=both
                ),
                "item": gb.ItemSet((torch.arange(5, 10),), names=("seeds",)),
            }
        )


def test_ItemSetDict_length():
    """len() and indexing of ItemSetDicts with sized and unsized members."""
    # Sized single-tensor members: the length is the sum of member lengths.
    users = torch.arange(0, 5)
    items = torch.arange(0, 5)
    by_type = gb.ItemSetDict(
        {"user": gb.ItemSet(users), "item": gb.ItemSet(items)}
    )
    assert len(by_type) == len(users) + len(items)

    # Sized tuple members: the length counts rows, not tuple entries.
    like_pairs = torch.arange(0, 10).reshape(-1, 2)
    like_negs = torch.arange(10, 20).reshape(-1, 2)
    follow_pairs = torch.arange(0, 10).reshape(-1, 2)
    follow_negs = torch.arange(10, 20).reshape(-1, 2)
    by_etype = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((like_pairs, like_negs)),
            "user:follow:user": gb.ItemSet((follow_pairs, follow_negs)),
        }
    )
    assert len(by_etype) == like_pairs.size(0) + follow_pairs.size(0)

    class InvalidLength:
        """Iterable that defines neither ``__len__`` nor ``__getitem__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    def expect_unsized(unsized):
        # The member's length error surfaces from len(); indexing is refused
        # by the ItemSetDict itself.
        with pytest.raises(
            TypeError, match="ItemSet instance doesn't have valid length."
        ):
            _ = len(unsized)
        with pytest.raises(
            TypeError, match="ItemSetDict instance doesn't support indexing."
        ):
            _ = unsized[0]

    # Unsized single-iterable members.
    single_members = {
        "user": gb.ItemSet(InvalidLength()),
        "item": gb.ItemSet(InvalidLength()),
    }
    expect_unsized(gb.ItemSetDict(single_members))

    # Unsized tuple-of-iterables members.
    tuple_members = {
        "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
        "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
    }
    expect_unsized(gb.ItemSetDict(tuple_members))


def test_ItemSetDict_iteration_seed_nodes():
    """Iteration, integer and slice indexing of a seed-node ItemSetDict."""
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(5, 10)
    members = {
        "user": gb.ItemSet(user_ids, names="seeds"),
        "item": gb.ItemSet(item_ids, names="seeds"),
    }
    # Flatten the members, in dict order, into (node_type, id) pairs.
    expected = [
        (ntype, node) for ntype, member in members.items() for node in member
    ]
    item_set = gb.ItemSetDict(members)
    assert item_set.names == ("seeds",)

    # Each step yields a one-entry dict {node_type: id}; positive and
    # negative integer indices return the same dict.
    for pos, entry in enumerate(item_set):
        ntype, node = expected[pos]
        assert len(entry) == 1
        assert isinstance(entry, dict)
        assert ntype in entry
        assert entry[ntype] == node
        assert item_set[pos] == entry
        assert item_set[pos - len(item_set)] == entry

    # A full slice returns one tensor per node type.
    assert torch.equal(item_set[:]["user"], user_ids)
    assert torch.equal(item_set[:]["item"], item_ids)

    # Partial slices contain only the node types they overlap.
    head = item_set[:3]
    assert len(list(head.keys())) == 1
    assert torch.equal(head["user"], user_ids[:3])
    tail = item_set[7:]
    assert len(list(tail.keys())) == 1
    assert torch.equal(tail["item"], item_ids[2:])
    middle = item_set[3:7]
    assert len(list(middle.keys())) == 2
    assert torch.equal(middle["user"], user_ids[3:5])
    assert torch.equal(middle["item"], item_ids[:2])

    # Exception cases.
    with pytest.raises(AssertionError, match="Step must be 1."):
        _ = item_set[::2]
    for backwards in (slice(5, 3), slice(-1, 3)):
        with pytest.raises(
            AssertionError, match="Start must be smaller than stop."
        ):
            _ = item_set[backwards]
    for out_of_range in (20, -20):
        with pytest.raises(
            IndexError, match="ItemSetDict index out of range."
        ):
            _ = item_set[out_of_range]
    with pytest.raises(
        TypeError, match="ItemSetDict indices must be int or slice."
    ):
        _ = item_set[torch.arange(3)]


def test_ItemSetDict_iteration_seed_nodes_labels():
    """Iteration and indexing of an ItemSetDict of (seed, label) members."""
    user_ids = torch.arange(0, 5)
    user_labels = torch.randint(0, 3, (5,))
    item_ids = torch.arange(5, 10)
    item_labels = torch.randint(0, 3, (5,))
    names = ("seeds", "labels")
    members = {
        "user": gb.ItemSet((user_ids, user_labels), names=names),
        "item": gb.ItemSet((item_ids, item_labels), names=names),
    }
    # Flatten the members, in dict order, into (node_type, (seed, label)).
    expected = [(key, v) for key, member in members.items() for v in member]
    item_set = gb.ItemSetDict(members)
    assert item_set.names == names

    # Each step yields a one-entry dict {node_type: (seed, label)}.
    for pos, entry in enumerate(item_set):
        ntype, value = expected[pos]
        assert len(entry) == 1
        assert isinstance(entry, dict)
        assert ntype in entry
        assert entry[ntype] == value
        assert item_set[pos] == entry

    # A full slice returns the (seeds, labels) columns per node type.
    for ntype, ids, labels in (
        ("user", user_ids, user_labels),
        ("item", item_ids, item_labels),
    ):
        assert torch.equal(item_set[:][ntype][0], ids)
        assert torch.equal(item_set[:][ntype][1], labels)


def test_ItemSetDict_iteration_node_pairs():
    """Iteration and indexing of an ItemSetDict keyed by edge type."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    etypes = ("user:like:item", "user:follow:user")
    members = {
        etype: gb.ItemSet(node_pairs, names="seeds") for etype in etypes
    }
    # Flatten the members, in dict order, into (edge_type, pair) tuples.
    expected = [
        (etype, pair) for etype, member in members.items() for pair in member
    ]
    item_set = gb.ItemSetDict(members)
    assert item_set.names == ("seeds",)

    # Each step yields a one-entry dict {edge_type: pair}.
    for pos, entry in enumerate(item_set):
        etype, pair = expected[pos]
        assert len(entry) == 1
        assert isinstance(entry, dict)
        assert etype in entry
        assert torch.equal(entry[etype], pair)
        indexed = item_set[pos]
        assert indexed.keys() == entry.keys()
        only_key = next(iter(entry.keys()))
        assert torch.equal(indexed[only_key], entry[only_key])

    # A full slice returns the whole pair tensor per edge type.
    for etype in etypes:
        assert torch.equal(item_set[:][etype], node_pairs)


def test_ItemSetDict_iteration_node_pairs_labels():
    """Iteration and indexing of an ItemSetDict of (pairs, labels) members."""
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    etypes = ("user:like:item", "user:follow:user")
    members = {
        etype: gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
        for etype in etypes
    }
    # Flatten the members, in dict order, into (edge_type, (pair, label)).
    expected = [(key, v) for key, member in members.items() for v in member]
    item_set = gb.ItemSetDict(members)
    assert item_set.names == ("seeds", "labels")

    # Each step yields a one-entry dict {edge_type: (pair, label)}.
    for pos, entry in enumerate(item_set):
        assert len(entry) == 1
        assert isinstance(entry, dict)
        etype, (pair, label) = expected[pos]
        assert etype in entry
        assert torch.equal(entry[etype][0], pair)
        assert entry[etype][1] == label
        indexed = item_set[pos]
        assert indexed.keys() == entry.keys()
        only_key = next(iter(entry.keys()))
        assert torch.equal(indexed[only_key][0], entry[only_key][0])
        assert torch.equal(indexed[only_key][1], entry[only_key][1])

    # A full slice returns the (pairs, labels) columns per edge type.
    for etype in etypes:
        assert torch.equal(item_set[:][etype][0], node_pairs)
        assert torch.equal(item_set[:][etype][1], labels)


def test_ItemSetDict_iteration_node_pairs_labels_indexes():
    """Iteration and indexing of an ItemSetDict whose members hold node
    pairs together with a label and an index per pair.
    """
    # Node pairs, labels and indexes: three parallel columns of length 5.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.tensor([1, 1, 0, 0, 0])
    indexes = torch.tensor([0, 1, 0, 0, 1])
    columns = (node_pairs, labels, indexes)
    names = ("seeds", "labels", "indexes")
    etypes = ("user:like:item", "user:follow:user")
    node_pairs_labels_indexes = {
        etype: gb.ItemSet(columns, names=names) for etype in etypes
    }
    # Flatten the members, in dict order, into (edge_type, row) tuples.
    expected_data = [
        (key, v)
        for key, member in node_pairs_labels_indexes.items()
        for v in member
    ]
    item_set = gb.ItemSetDict(node_pairs_labels_indexes)
    assert item_set.names == names

    # Iterating over ItemSetDict and indexing one by one: each step yields a
    # one-entry dict {edge_type: (pair, label, index)}.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, value = expected_data[i]
        assert key in item
        for col in range(len(columns)):
            assert torch.equal(item[key][col], value[col])
        indexed = item_set[i]
        assert indexed.keys() == item.keys()
        only_key = next(iter(item.keys()))
        for col in range(len(columns)):
            assert torch.equal(indexed[only_key][col], item[only_key][col])

    # Indexing with a slice returns every column per edge type.
    for etype in etypes:
        sliced = item_set[:][etype]
        for col, column in enumerate(columns):
            assert torch.equal(sliced[col], column)


def test_ItemSet_repr():
    """str(ItemSet) lists the stored items followed by the names."""
    cases = [
        # Single name.
        (
            gb.ItemSet(torch.arange(0, 5), names="seeds"),
            (
                "ItemSet(\n"
                "    items=(tensor([0, 1, 2, 3, 4]),),\n"
                "    names=('seeds',),\n"
                ")"
            ),
        ),
        # Multiple names.
        (
            gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seeds", "labels"),
            ),
            (
                "ItemSet(\n"
                "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
                "    names=('seeds', 'labels'),\n"
                ")"
            ),
        ),
    ]
    for item_set, expected_str in cases:
        assert str(item_set) == expected_str, item_set


def test_ItemSetDict_repr():
    """str(ItemSetDict) nests each member's repr, then the shared names."""
    seeds_only = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
        }
    )
    seeds_only_str = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]),),\n"
        "                 names=('seeds',),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]),),\n"
        "                 names=('seeds',),\n"
        "             )},\n"
        "    names=('seeds',),\n"
        ")"
    )

    with_labels = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seeds", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seeds", "labels"),
            ),
        }
    )
    with_labels_str = (
        "ItemSetDict(\n"
        "    itemsets={'user': ItemSet(\n"
        "                 items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
        "                 names=('seeds', 'labels'),\n"
        "             ), 'item': ItemSet(\n"
        "                 items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
        "                 names=('seeds', 'labels'),\n"
        "             )},\n"
        "    names=('seeds', 'labels'),\n"
        ")"
    )

    for item_set, expected_str in (
        (seeds_only, seeds_only_str),
        (with_labels, with_labels_str),
    ):
        assert str(item_set) == expected_str, item_set