test_itemset.py 11.3 KB
Newer Older
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
from dgl import graphbolt as gb
Rhett Ying's avatar
Rhett Ying committed
7
from torch.testing import assert_close
8
9


10
11
def test_ItemSet_names():
    """Check how ``ItemSet`` records and validates its ``names``."""
    # ItemSet with single name: a bare string is exposed as a 1-tuple.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert item_set.names == ("seed_nodes",)

    # ItemSet with multiple names.
    item_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_nodes", "labels"),
    )
    assert item_set.names == ("seed_nodes", "labels")

    # ItemSet without name.
    item_set = gb.ItemSet(torch.arange(0, 5))
    assert item_set.names is None

    # ItemSet with mismatched items and names: one item but two names must
    # be rejected with the exact message below.
    with pytest.raises(
        AssertionError,
        match=re.escape("Number of items (1) and names (2) must match."),
    ):
        _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))


def test_ItemSet_length():
    """Check ``len()`` and iteration of ``ItemSet`` for sized/unsized items."""
    # Single iterable with valid length.
    ids = torch.arange(0, 5)
    item_set = gb.ItemSet(ids)
    assert len(item_set) == 5
    # Test __iter__ method. Same as below. Materialize first so that a
    # truncated iteration is caught by the count check.
    items = list(item_set)
    assert len(items) == 5
    for i, item in enumerate(items):
        assert i == item.item()

    # Tuple of iterables with valid length.
    item_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(item_set) == 5
    items = list(item_set)
    assert len(items) == 5
    for i, (item1, item2) in enumerate(items):
        assert i == item1.item()
        assert i + 5 == item2.item()

    class InvalidLength:
        # Iterable that defines no __len__, so len() on it raises TypeError.
        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable with invalid length: len() fails, iteration still works.
    item_set = gb.ItemSet(InvalidLength())
    with pytest.raises(TypeError):
        _ = len(item_set)
    items = list(item_set)
    assert len(items) == 3
    for i, item in enumerate(items):
        assert i == item

    # Tuple of iterables with invalid length.
    item_set = gb.ItemSet((InvalidLength(), InvalidLength()))
    with pytest.raises(TypeError):
        _ = len(item_set)
    items = list(item_set)
    assert len(items) == 3
    for i, (item1, item2) in enumerate(items):
        assert i == item1
        assert i == item2
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114


def test_ItemSet_iteration_seed_nodes():
    """Check that an ItemSet of node IDs yields every ID, in order."""
    # Node IDs.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    assert item_set.names == ("seed_nodes",)
    items = list(item_set)
    # enumerate() alone would not notice a truncated iteration.
    assert len(items) == 5
    for i, item in enumerate(items):
        assert i == item.item()


def test_ItemSet_iteration_seed_nodes_labels():
    """Check that a (seed_nodes, labels) ItemSet yields aligned pairs."""
    # Node IDs and labels.
    seed_nodes = torch.arange(0, 5)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    assert item_set.names == ("seed_nodes", "labels")
    items = list(item_set)
    # enumerate() alone would not notice a truncated iteration.
    assert len(items) == len(seed_nodes)
    for i, (seed_node, label) in enumerate(items):
        assert seed_node == seed_nodes[i]
        assert label == labels[i]


def test_ItemSet_iteration_node_pairs():
    """Check that a node-pair ItemSet yields one (src, dst) row per pair."""
    # Node pairs: a (5, 2) tensor.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    assert item_set.names == ("node_pairs",)
    items = list(item_set)
    # enumerate() alone would not notice a truncated iteration.
    assert len(items) == node_pairs.size(0)
    for i, (src, dst) in enumerate(items):
        assert node_pairs[i][0] == src
        assert node_pairs[i][1] == dst


def test_ItemSet_iteration_node_pairs_labels():
    """Check that a (node_pairs, labels) ItemSet yields aligned items."""
    # Node pairs and labels
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    assert item_set.names == ("node_pairs", "labels")
    items = list(item_set)
    # enumerate() alone would not notice a truncated iteration.
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, label) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert labels[i] == label


def test_ItemSet_iteration_node_pairs_neg_dsts():
    """Check that a (node_pairs, negative_dsts) ItemSet yields aligned rows."""
    # Node pairs (5, 2) and negative destinations (5, 3).
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    assert item_set.names == ("node_pairs", "negative_dsts")
    items = list(item_set)
    # enumerate() alone would not notice a truncated iteration.
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, neg_dst) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert torch.equal(neg_dsts[i], neg_dst)


def test_ItemSet_iteration_graphs():
    """Check that an unnamed ItemSet over a list of graphs yields each graph."""
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    assert item_set.names is None
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(graphs)
    for graph, item in zip(graphs, items):
        assert graph == item
130
131
132
133
134
135


def test_ItemSetDict_names():
    """Check how ``ItemSetDict`` derives ``names`` from its member ItemSets."""
    # ItemSetDict with single name.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
        }
    )
    assert item_set.names == ("seed_nodes",)

    # ItemSetDict with multiple names.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_nodes", "labels"),
            ),
        }
    )
    assert item_set.names == ("seed_nodes", "labels")

    # ItemSetDict with no name.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(torch.arange(0, 5)),
            "item": gb.ItemSet(torch.arange(5, 10)),
        }
    )
    assert item_set.names is None

    # ItemSetDict with mismatched items and names: member ItemSets whose
    # names differ must be rejected.
    with pytest.raises(
        AssertionError,
        match=re.escape("All itemsets must have the same names."),
    ):
        _ = gb.ItemSetDict(
            {
                "user": gb.ItemSet(
                    (torch.arange(0, 5), torch.arange(5, 10)),
                    names=("seed_nodes", "labels"),
                ),
                "item": gb.ItemSet(
                    (torch.arange(5, 10),), names=("seed_nodes",)
                ),
            }
        )


184
185
def test_ItemSetDict_length():
    """Check that ``len(ItemSetDict)`` sums member lengths or raises."""
    # Single iterable with valid length.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(user_ids),
            "item": gb.ItemSet(item_ids),
        }
    )
    assert len(item_set) == len(user_ids) + len(item_ids)

    # Tuple of iterables with valid length.
    node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts_like = torch.arange(10, 20).reshape(-1, 2)
    node_pairs_follow = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts_follow = torch.arange(10, 20).reshape(-1, 2)
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)),
            "user:follow:user": gb.ItemSet(
                (node_pairs_follow, neg_dsts_follow)
            ),
        }
    )
    assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0)

    class InvalidLength:
        # Iterable that defines no __len__, so len() on it raises TypeError.
        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable with invalid length.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(InvalidLength()),
            "item": gb.ItemSet(InvalidLength()),
        }
    )
    with pytest.raises(TypeError):
        _ = len(item_set)

    # Tuple of iterables with invalid length.
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
            "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
        }
    )
    with pytest.raises(TypeError):
        _ = len(item_set)
Rhett Ying's avatar
Rhett Ying committed
234
235


236
237
238
239
def test_ItemSetDict_iteration_seed_nodes():
    """Check that ItemSetDict yields one ``{type: seed_node}`` dict per item."""
    # Node IDs.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(5, 10)
    ids = {
        "user": gb.ItemSet(user_ids, names="seed_nodes"),
        "item": gb.ItemSet(item_ids, names="seed_nodes"),
    }
    # Expected stream, in dict insertion order: every "user" ID, then every
    # "item" ID. Built from the source tensors rather than by iterating the
    # ItemSets under test, so an ItemSet iteration bug cannot hide itself.
    expected = [("user", v) for v in user_ids]
    expected += [("item", v) for v in item_ids]
    item_set = gb.ItemSetDict(ids)
    assert item_set.names == ("seed_nodes",)
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(expected)
    for item, (key, value) in zip(items, expected):
        assert isinstance(item, dict)
        assert len(item) == 1
        assert key in item
        assert item[key] == value


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def test_ItemSetDict_iteration_seed_nodes_labels():
    """Check ItemSetDict iteration over (seed_nodes, labels) member sets."""
    # Node IDs and labels.
    user_ids = torch.arange(0, 5)
    user_labels = torch.randint(0, 3, (5,))
    item_ids = torch.arange(5, 10)
    item_labels = torch.randint(0, 3, (5,))
    ids_labels = {
        "user": gb.ItemSet(
            (user_ids, user_labels), names=("seed_nodes", "labels")
        ),
        "item": gb.ItemSet(
            (item_ids, item_labels), names=("seed_nodes", "labels")
        ),
    }
    # Expected (type, seed_node, label) triples in dict insertion order,
    # built from the source tensors rather than from the ItemSets under test.
    expected = [("user", n, lbl) for n, lbl in zip(user_ids, user_labels)]
    expected += [("item", n, lbl) for n, lbl in zip(item_ids, item_labels)]
    item_set = gb.ItemSetDict(ids_labels)
    assert item_set.names == ("seed_nodes", "labels")
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(expected)
    for item, (key, seed_node, label) in zip(items, expected):
        assert isinstance(item, dict)
        assert len(item) == 1
        assert key in item
        # Compare fields one by one instead of tuple-of-tensors equality.
        assert item[key][0] == seed_node
        assert item[key][1] == label


def test_ItemSetDict_iteration_node_pairs():
    """Check ItemSetDict iteration over node-pair member sets."""
    # Node pairs: a (5, 2) tensor shared by both edge types.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
        "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
    }
    # Expected (edge type, pair) stream in dict insertion order, built from
    # the source tensor rather than from the ItemSets under test.
    expected_data = [
        (key, node_pair) for key in node_pairs_dict for node_pair in node_pairs
    ]
    item_set = gb.ItemSetDict(node_pairs_dict)
    assert item_set.names == ("node_pairs",)
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(expected_data)
    for item, (key, node_pair) in zip(items, expected_data):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert key in item
        assert torch.equal(item[key], node_pair)
Rhett Ying's avatar
Rhett Ying committed
299
300


301
def test_ItemSetDict_iteration_node_pairs_labels():
    """Check ItemSetDict iteration over (node_pairs, labels) member sets."""
    # Node pairs and labels
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    node_pairs_labels = {
        "user:like:item": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
    }
    # Expected (edge type, pair, label) triples in dict insertion order,
    # built from the source tensors rather than from the ItemSets under test.
    expected_data = [
        (key, node_pair, label)
        for key in node_pairs_labels
        for node_pair, label in zip(node_pairs, labels)
    ]
    item_set = gb.ItemSetDict(node_pairs_labels)
    assert item_set.names == ("node_pairs", "labels")
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(expected_data)
    for item, (key, node_pair, label) in zip(items, expected_data):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert key in item
        assert torch.equal(item[key][0], node_pair)
        assert item[key][1] == label


def test_ItemSetDict_iteration_node_pairs_neg_dsts():
    """Check ItemSetDict iteration over (node_pairs, negative_dsts) sets."""
    # Node pairs (5, 2) and negative destinations (5, 3).
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    node_pairs_neg_dsts = {
        "user:like:item": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
        ),
    }
    # Expected (edge type, pair, negative dsts) triples in dict insertion
    # order, built from the source tensors rather than the ItemSets under test.
    expected_data = [
        (key, node_pair, neg_dst)
        for key in node_pairs_neg_dsts
        for node_pair, neg_dst in zip(node_pairs, neg_dsts)
    ]
    item_set = gb.ItemSetDict(node_pairs_neg_dsts)
    assert item_set.names == ("node_pairs", "negative_dsts")
    items = list(item_set)
    # zip() would silently stop early, so pin the count first.
    assert len(items) == len(expected_data)
    for item, (key, node_pair, neg_dst) in zip(items, expected_data):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert key in item
        assert torch.equal(item[key][0], node_pair)
        assert torch.equal(item[key][1], neg_dst)