test_itemset.py 11.7 KB
Newer Older
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
from dgl import graphbolt as gb
Rhett Ying's avatar
Rhett Ying committed
7
from torch.testing import assert_close
8
9


10
11
def test_ItemSet_names():
    """Check how ``ItemSet`` reports and validates its ``names`` argument."""
    pair_of_names = ("seed_nodes", "labels")

    # A single string name is exposed as a one-element tuple.
    single_named = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert single_named.names == ("seed_nodes",)

    # One name per item when several items are given.
    multi_named = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=pair_of_names,
    )
    assert multi_named.names == ("seed_nodes", "labels")

    # No names supplied: the attribute is None.
    unnamed = gb.ItemSet(torch.arange(0, 5))
    assert unnamed.names is None

    # An integer item combined with two names must be rejected.
    too_many_for_int = re.escape(
        "Number of names mustn't exceed 1 when item is an integer."
    )
    with pytest.raises(AssertionError, match=too_many_for_int):
        _ = gb.ItemSet(5, names=pair_of_names)

    # One item combined with two names must be rejected.
    count_mismatch = re.escape("Number of items (1) and names (2) must match.")
    with pytest.raises(AssertionError, match=count_mismatch):
        _ = gb.ItemSet(torch.arange(0, 5), names=pair_of_names)


def test_ItemSet_length():
    """Check ``len()`` and iteration of ``ItemSet`` for each kind of item.

    Items with a length (an integer, a tensor, a tuple of tensors) must
    report it; iterables without ``__len__`` must make ``len()`` raise
    ``TypeError`` while still being iterable.
    """

    class InvalidLength:
        """Iterable over ``[0, 1, 2]`` that defines no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # Integer with valid length.
    int_set = gb.ItemSet(10)
    assert len(int_set) == 10
    # Exercise __iter__(): every yielded value equals its position.
    for position, value in enumerate(int_set):
        assert position == value

    # Single iterable with valid length.
    tensor_set = gb.ItemSet(torch.arange(0, 5))
    assert len(tensor_set) == 5
    for position, value in enumerate(tensor_set):
        assert position == value.item()

    # Tuple of iterables with valid length.
    tuple_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(tuple_set) == 5
    for position, (first, second) in enumerate(tuple_set):
        assert position == first.item()
        assert position + 5 == second.item()

    # Single iterable with invalid length.
    unsized_set = gb.ItemSet(InvalidLength())
    with pytest.raises(TypeError):
        _ = len(unsized_set)
    for position, value in enumerate(unsized_set):
        assert position == value

    # Tuple of iterables with invalid length.
    unsized_tuple_set = gb.ItemSet((InvalidLength(), InvalidLength()))
    with pytest.raises(TypeError):
        _ = len(unsized_tuple_set)
    for position, (first, second) in enumerate(unsized_tuple_set):
        assert position == first
        assert position == second
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130


def test_ItemSet_iteration_seed_nodes():
    """Iterate an ``ItemSet`` of node IDs and check every yielded ID.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Node IDs.
    seed_nodes = torch.arange(0, 5)
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    assert item_set.names == ("seed_nodes",)
    items = list(item_set)
    assert len(items) == len(seed_nodes)
    for i, item in enumerate(items):
        assert i == item.item()


def test_ItemSet_iteration_seed_nodes_labels():
    """Iterate an ``ItemSet`` of (node ID, label) and check each pair.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Node IDs and labels.
    seed_nodes = torch.arange(0, 5)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    assert item_set.names == ("seed_nodes", "labels")
    items = list(item_set)
    assert len(items) == len(seed_nodes)
    for i, (seed_node, label) in enumerate(items):
        assert seed_node == seed_nodes[i]
        assert label == labels[i]


def test_ItemSet_iteration_node_pairs():
    """Iterate an ``ItemSet`` of node pairs and check each (src, dst) row.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Node pairs: a (5, 2) tensor, one pair per row.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    assert item_set.names == ("node_pairs",)
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (src, dst) in enumerate(items):
        assert node_pairs[i][0] == src
        assert node_pairs[i][1] == dst


def test_ItemSet_iteration_node_pairs_labels():
    """Iterate an ``ItemSet`` of (node pair, label) and check each item.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Node pairs and labels.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    assert item_set.names == ("node_pairs", "labels")
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, label) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert labels[i] == label


def test_ItemSet_iteration_node_pairs_neg_dsts():
    """Iterate an ``ItemSet`` of (node pair, negative dsts) and check each.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Node pairs and negative destinations: (5, 2) and (5, 3) tensors.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    assert item_set.names == ("node_pairs", "negative_dsts")
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, neg_dst) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert torch.equal(neg_dsts[i], neg_dst)


def test_ItemSet_iteration_graphs():
    """Iterate an ``ItemSet`` built from a list of graphs.

    The number of yielded items is asserted too, so the test cannot pass
    vacuously when iteration yields nothing or stops early.
    """
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    assert item_set.names is None
    items = list(item_set)
    assert len(items) == len(graphs)
    for i, item in enumerate(items):
        assert graphs[i] == item
146
147
148
149
150
151


def test_ItemSetDict_names():
    """Check that ``ItemSetDict`` reports its names and rejects mismatches."""
    pair_of_names = ("seed_nodes", "labels")

    # ItemSetDict with single name.
    single_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    assert single_named.names == ("seed_nodes",)

    # ItemSetDict with multiple names.
    multi_named = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=pair_of_names,
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=pair_of_names,
            ),
        }
    )
    assert multi_named.names == ("seed_nodes", "labels")

    # ItemSetDict with no name.
    unnamed = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert unnamed.names is None

    # ItemSetDict whose member item sets carry different names.
    names_differ = re.escape("All itemsets must have the same names.")
    with pytest.raises(AssertionError, match=names_differ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)),
                    names=pair_of_names,
                ),
                "item": gb.ItemSet(
                    (torch.arange(5, 10),), names=("seed_nodes",)
                ),
            }
        )


200
201
def test_ItemSetDict_length():
    """Check ``len()`` of ``ItemSetDict`` for sized and unsized item sets.

    With sized members the length is the sum of the members' lengths; with
    members that define no ``__len__``, ``len()`` must raise ``TypeError``.
    """

    class InvalidLength:
        """Iterable over ``[0, 1, 2]`` that defines no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable with valid length.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    sized_single = gb.ItemSetDict(
        {"user": gb.ItemSet(user_ids), "item": gb.ItemSet(item_ids)}
    )
    assert len(sized_single) == len(user_ids) + len(item_ids)

    # Tuple of iterables with valid length.
    like_pairs = torch.arange(0, 10).reshape(-1, 2)
    like_neg_dsts = torch.arange(10, 20).reshape(-1, 2)
    follow_pairs = torch.arange(0, 10).reshape(-1, 2)
    follow_neg_dsts = torch.arange(10, 20).reshape(-1, 2)
    sized_tuples = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((like_pairs, like_neg_dsts)),
            "user:follow:user": gb.ItemSet((follow_pairs, follow_neg_dsts)),
        }
    )
    assert len(sized_tuples) == like_pairs.size(0) + follow_pairs.size(0)

    # Single iterable with invalid length.
    unsized_single = gb.ItemSetDict(
        {
            "user": gb.ItemSet(InvalidLength()),
            "item": gb.ItemSet(InvalidLength()),
        }
    )
    with pytest.raises(TypeError):
        _ = len(unsized_single)

    # Tuple of iterables with invalid length.
    unsized_tuples = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
            "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
        }
    )
    with pytest.raises(TypeError):
        _ = len(unsized_tuples)
Rhett Ying's avatar
Rhett Ying committed
250
251


252
253
254
255
def test_ItemSetDict_iteration_seed_nodes():
    """Iterate an ``ItemSetDict`` of node IDs and check each ``{key: id}``.

    The expected sequence is built by iterating the member ``ItemSet``s, so
    the yielded count is additionally pinned to the input sizes; otherwise
    a broken iterator would empty both sides and the test would still pass.
    """
    # Node IDs.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(5, 10)
    ids = {
        "user": gb.ItemSet(user_ids, names="seed_nodes"),
        "item": gb.ItemSet(item_ids, names="seed_nodes"),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    assert item_set.names == ("seed_nodes",)
    items = list(item_set)
    assert len(items) == len(user_ids) + len(item_ids)
    assert len(items) == len(chained_ids)
    for i, item in enumerate(items):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, expected_id = chained_ids[i]
        assert key in item
        assert item[key] == expected_id


272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def test_ItemSetDict_iteration_seed_nodes_labels():
    """Iterate an ``ItemSetDict`` of (node ID, label) and check each item.

    The expected sequence is built by iterating the member ``ItemSet``s, so
    the yielded count is additionally pinned to the input sizes; otherwise
    a broken iterator would empty both sides and the test would still pass.
    Each yielded value is unpacked and compared element by element instead
    of relying on ``==`` between tuples of tensors.
    """
    # Node IDs and labels.
    user_ids = torch.arange(0, 5)
    user_labels = torch.randint(0, 3, (5,))
    item_ids = torch.arange(5, 10)
    item_labels = torch.randint(0, 3, (5,))
    ids_labels = {
        "user": gb.ItemSet(
            (user_ids, user_labels), names=("seed_nodes", "labels")
        ),
        "item": gb.ItemSet(
            (item_ids, item_labels), names=("seed_nodes", "labels")
        ),
    }
    chained_ids = []
    for key, value in ids_labels.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids_labels)
    assert item_set.names == ("seed_nodes", "labels")
    items = list(item_set)
    assert len(items) == len(user_ids) + len(item_ids)
    assert len(items) == len(chained_ids)
    for i, item in enumerate(items):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, (expected_id, expected_label) = chained_ids[i]
        assert key in item
        # Unpacking also enforces that exactly two elements are yielded.
        seed_node, label = item[key]
        assert seed_node == expected_id
        assert label == expected_label


def test_ItemSetDict_iteration_node_pairs():
    """Iterate an ``ItemSetDict`` of node pairs and check each ``{key: pair}``.

    The expected sequence is built by iterating the member ``ItemSet``s, so
    the yielded count is additionally pinned to the input sizes; otherwise
    a broken iterator would empty both sides and the test would still pass.
    """
    # Node pairs: a (5, 2) tensor shared by both edge types.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
        "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
    }
    expected_data = []
    for key, value in node_pairs_dict.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_dict)
    assert item_set.names == ("node_pairs",)
    items = list(item_set)
    assert len(items) == len(node_pairs_dict) * node_pairs.size(0)
    assert len(items) == len(expected_data)
    for i, item in enumerate(items):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, expected_pair = expected_data[i]
        assert key in item
        assert torch.equal(item[key], expected_pair)
Rhett Ying's avatar
Rhett Ying committed
315
316


317
def test_ItemSetDict_iteration_node_pairs_labels():
    """Iterate an ``ItemSetDict`` of (node pair, label) and check each item.

    The expected sequence is built by iterating the member ``ItemSet``s, so
    the yielded count is additionally pinned to the input sizes; otherwise
    a broken iterator would empty both sides and the test would still pass.
    """
    # Node pairs and labels, shared by both edge types.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    node_pairs_labels = {
        "user:like:item": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
    }
    expected_data = []
    for key, value in node_pairs_labels.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_labels)
    assert item_set.names == ("node_pairs", "labels")
    items = list(item_set)
    assert len(items) == len(node_pairs_labels) * node_pairs.size(0)
    assert len(items) == len(expected_data)
    for i, item in enumerate(items):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, value = expected_data[i]
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert item[key][1] == value[1]


def test_ItemSetDict_iteration_node_pairs_neg_dsts():
    """Iterate an ``ItemSetDict`` of (node pair, negative dsts) items.

    The expected sequence is built by iterating the member ``ItemSet``s, so
    the yielded count is additionally pinned to the input sizes; otherwise
    a broken iterator would empty both sides and the test would still pass.
    """
    # Node pairs and negative destinations: (5, 2) and (5, 3) tensors.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    node_pairs_neg_dsts = {
        "user:like:item": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
    }
    expected_data = []
    for key, value in node_pairs_neg_dsts.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_neg_dsts)
    assert item_set.names == ("node_pairs", "negative_dsts")
    items = list(item_set)
    assert len(items) == len(node_pairs_neg_dsts) * node_pairs.size(0)
    assert len(items) == len(expected_data)
    for i, item in enumerate(items):
        assert len(item) == 1
        assert isinstance(item, dict)
        key, value = expected_data[i]
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert torch.equal(item[key][1], value[1])