test_itemset.py 8.38 KB
Newer Older
1
2
import re

Rhett Ying's avatar
Rhett Ying committed
3
import dgl
4
import pytest
Rhett Ying's avatar
Rhett Ying committed
5
import torch
6
from dgl import graphbolt as gb
Rhett Ying's avatar
Rhett Ying committed
7
from torch.testing import assert_close
8
9


10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def test_ItemSet_names():
    """Check the ``names`` attribute of ``gb.ItemSet`` for each naming mode.

    Exercises a single string name, a tuple of names, no names at all, and
    a names tuple whose length disagrees with the number of items (expected
    to raise ``AssertionError``).
    """
    seeds = torch.arange(0, 5)
    labels = torch.arange(5, 10)
    pair_names = ("seed_node", "label")

    # A bare string name is expected back as a 1-tuple.
    single_named = gb.ItemSet(seeds, names="seed_node")
    assert single_named.names == ("seed_node",)

    # One name per item in the tuple.
    multi_named = gb.ItemSet((seeds, labels), names=pair_names)
    assert multi_named.names == ("seed_node", "label")

    # Names omitted: the attribute is expected to be None.
    unnamed = gb.ItemSet(seeds)
    assert unnamed.names is None

    # Two names for a single item must be rejected.
    expected_error = re.escape("Number of items (1) and names (2) must match.")
    with pytest.raises(AssertionError, match=expected_error):
        gb.ItemSet(seeds, names=pair_names)


def test_ItemSetDict_names():
    """Check the ``names`` attribute of ``gb.ItemSetDict``.

    The dict is built from member ``gb.ItemSet`` objects that all share a
    single name, all share several names, or have no names. A final case
    mixes members with different names and expects ``AssertionError``.
    """
    pair_names = ("seed_node", "label")

    # Every member carries the same single name.
    single_named = {
        "user": gb.ItemSet(torch.arange(0, 5), names="seed_node"),
        "item": gb.ItemSet(torch.arange(5, 10), names="seed_node"),
    }
    assert gb.ItemSetDict(single_named).names == ("seed_node",)

    # Every member carries the same pair of names.
    multi_named = {
        "user": gb.ItemSet(
            (torch.arange(0, 5), torch.arange(5, 10)), names=pair_names
        ),
        "item": gb.ItemSet(
            (torch.arange(5, 10), torch.arange(10, 15)), names=pair_names
        ),
    }
    assert gb.ItemSetDict(multi_named).names == ("seed_node", "label")

    # No member carries a name.
    unnamed = {
        "user": gb.ItemSet(torch.arange(0, 5)),
        "item": gb.ItemSet(torch.arange(5, 10)),
    }
    assert gb.ItemSetDict(unnamed).names is None

    # Members disagree on their names: building the dict must fail.
    two_name_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)), names=pair_names
    )
    one_name_set = gb.ItemSet((torch.arange(5, 10),), names=("seed_node",))
    expected_error = re.escape("All itemsets must have the same names.")
    with pytest.raises(AssertionError, match=expected_error):
        gb.ItemSetDict({"user": two_name_set, "item": one_name_set})


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def test_ItemSet_valid_length():
    """Check ``len()`` of a ``gb.ItemSet`` built from sized tensors."""
    # A lone tensor holding five IDs.
    single = gb.ItemSet(torch.arange(0, 5))
    assert len(single) == 5

    # Two parallel tensors of five entries each; the expected length is
    # still five, not ten.
    src, dst = torch.arange(0, 5), torch.arange(5, 10)
    paired = gb.ItemSet((src, dst))
    assert len(paired) == 5


def test_ItemSet_invalid_length():
    """Check that ``len()`` raises ``TypeError`` for items lacking ``__len__``."""

    class Unsized:
        """Iterable that defines ``__iter__`` but deliberately no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # First a single iterable, then a tuple of iterables.
    for items in (Unsized(), (Unsized(), Unsized())):
        unsized_set = gb.ItemSet(items)
        with pytest.raises(TypeError):
            len(unsized_set)


def test_ItemSetDict_valid_length():
    """Check ``len()`` of a ``gb.ItemSetDict`` built from sized item sets.

    The expected length is the sum of the lengths of the member item sets.
    """
    # Single iterable per member.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(user_ids),
            "item": gb.ItemSet(item_ids),
        }
    )
    assert len(item_set) == len(user_ids) + len(item_ids)

    # Tuple of iterables per member.
    like = (torch.arange(0, 5), torch.arange(0, 5))
    follow = (torch.arange(0, 5), torch.arange(5, 10))
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet(like),
            "user:follow:user": gb.ItemSet(follow),
        }
    )
    # Each member contributes the length of one of its tensors, not of the
    # enclosing tuple.
    assert len(item_set) == len(like[0]) + len(follow[0])


def test_ItemSetDict_invalid_length():
    """Check that ``len()`` of a ``gb.ItemSetDict`` raises ``TypeError``
    when its member item sets wrap iterables that define no ``__len__``.
    """

    class InvalidLength:
        """Iterable that defines ``__iter__`` but deliberately no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # Single iterable per member.
    item_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(InvalidLength()),
            "item": gb.ItemSet(InvalidLength()),
        }
    )
    with pytest.raises(TypeError):
        _ = len(item_set)

    # Tuple of iterables per member.
    item_set = gb.ItemSetDict(
        {
            "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
            "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
        }
    )
    with pytest.raises(TypeError):
        _ = len(item_set)
Rhett Ying's avatar
Rhett Ying committed
161
162
163
164


def test_ItemSet_node_edge_ids():
    """Check that iterating a ``gb.ItemSet`` of IDs yields them in order."""
    # Node or edge IDs.
    item_set = gb.ItemSet(torch.arange(0, 5))
    # The IDs are 0..4, so each yielded value must equal its position.
    for i, item in enumerate(item_set):
        assert i == item.item()


def test_ItemSet_graphs():
    """Check that iterating a ``gb.ItemSet`` of graphs yields them in order."""
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    # Each yielded item must compare equal to the graph at the same index.
    for i, item in enumerate(item_set):
        assert graphs[i] == item


def test_ItemSet_node_pairs():
    """Check that a ``gb.ItemSet`` of two tensors iterates as aligned pairs."""
    # Node pairs.
    node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
    item_set = gb.ItemSet(node_pairs)
    # The i-th item must pair the i-th entry of each input tensor.
    for i, (src, dst) in enumerate(item_set):
        assert node_pairs[0][i] == src
        assert node_pairs[1][i] == dst


def test_ItemSet_node_pairs_labels():
    """Check that a ``gb.ItemSet`` of three tensors iterates as aligned triples."""
    # Node pairs and labels
    node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
    # The i-th item must combine the i-th entry of each input tensor.
    for i, (src, dst, label) in enumerate(item_set):
        assert node_pairs[0][i] == src
        assert node_pairs[1][i] == dst
        assert labels[i] == label


def test_ItemSet_head_tail_neg_tails():
    """Check iteration over a ``gb.ItemSet`` of heads, tails and negative tails."""
    # Head, tail and negative tails.
    heads = torch.arange(0, 5)
    tails = torch.arange(5, 10)
    neg_tails = torch.arange(10, 20).reshape(5, 2)
    item_set = gb.ItemSet((heads, tails, neg_tails))
    for i, (head, tail, negs) in enumerate(item_set):
        assert heads[i] == head
        assert tails[i] == tail
        # Each row of neg_tails holds two values, so compare it with
        # assert_close rather than a bare ``==``.
        assert_close(neg_tails[i], negs)


210
def test_ItemSetDict_node_edge_ids():
    """Check iteration over a ``gb.ItemSetDict`` whose members hold IDs.

    Every yielded item is expected to be a one-entry dict mapping a member
    key to one of that member's IDs, with members visited in dict order.
    """
    # Node or edge IDs
    ids = {
        "user:like:item": gb.ItemSet(torch.arange(0, 5)),
        "user:follow:user": gb.ItemSet(torch.arange(0, 5)),
    }
    # Flatten the members into the (key, value) sequence the dict should yield.
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert chained_ids[i][0] in item
        assert item[chained_ids[i][0]] == chained_ids[i][1]


227
def test_ItemSetDict_node_pairs():
    """Check iteration over a ``gb.ItemSetDict`` whose members hold node pairs.

    Every yielded item is expected to be a one-entry dict mapping a member
    key to one of that member's pairs, with members visited in dict order.
    """
    # Node pairs.
    node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs),
        "user:follow:user": gb.ItemSet(node_pairs),
    }
    # Flatten the members into the (key, value) sequence the dict should yield.
    expected_data = []
    for key, value in node_pairs_dict.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_dict)
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert expected_data[i][0] in item
        assert item[expected_data[i][0]] == expected_data[i][1]


245
def test_ItemSetDict_node_pairs_labels():
    """Check iteration over a ``gb.ItemSetDict`` of node pairs with labels.

    Every yielded item is expected to be a one-entry dict mapping a member
    key to one of that member's (src, dst, label) entries, with members
    visited in dict order.
    """
    # Node pairs and labels
    node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
    labels = torch.randint(0, 3, (5,))
    node_pairs_dict = {
        "user:like:item": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
        "user:follow:user": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
    }
    # Flatten the members into the (key, value) sequence the dict should yield.
    expected_data = []
    for key, value in node_pairs_dict.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_dict)
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert expected_data[i][0] in item
        assert item[expected_data[i][0]] == expected_data[i][1]


264
def test_ItemSetDict_head_tail_neg_tails():
    """Check iteration over a ``gb.ItemSetDict`` of heads, tails and negatives.

    Every yielded item is expected to be a one-entry dict mapping a member
    key to one of that member's (head, tail, neg_tails) entries, with
    members visited in dict order.
    """
    # Head, tail and negative tails.
    heads = torch.arange(0, 5)
    tails = torch.arange(5, 10)
    neg_tails = torch.arange(10, 20).reshape(5, 2)
    data_dict = {
        "user:like:item": gb.ItemSet((heads, tails, neg_tails)),
        "user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
    }
    # Flatten the members into the (key, value) sequence the dict should yield.
    expected_data = []
    for key, value in data_dict.items():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(data_dict)
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert expected_data[i][0] in item
        # Entries contain a two-element neg_tails row, so compare with
        # assert_close rather than a bare ``==``.
        assert_close(item[expected_data[i][0]], expected_data[i][1])