# test_integration.py: integration tests for the dgl.graphbolt data-loading pipeline.
import dgl
import dgl.graphbolt as gb
import dgl.sparse as dglsp
import torch


def test_integration_link_prediction():
    """Run a GraphBolt link-prediction pipeline end to end on a tiny graph.

    A 6-node / 10-edge graph is built from hard-coded CSC arrays, every edge
    is used as a seed node pair, and the pairs are pushed through: item
    sampling (batches of 4), uniform negative sampling, two rounds of
    neighbor sampling (fanout 1, with replacement), seed-edge exclusion,
    feature fetching and conversion to DGL minibatches.  The string form of
    every produced minibatch is compared against a recorded snapshot.

    The RNG is seeded so that the sampled negatives and neighbors are
    reproducible; the snapshots below are only valid for this seed.
    """
    torch.manual_seed(926)

    # CSC representation of the graph: 6 nodes (len(indptr) - 1), 10 edges.
    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Turn the CSC structure into an (num_edges, 2) tensor of node pairs by
    # going through the COO view of the equivalent sparse matrix.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    # One 2-dim feature row per node.
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    # One 3-dim feature row per edge.
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.from_fused_csc(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    # Second argument is presumably the negative ratio (one negative per
    # positive pair) -- the snapshots below show equal-sized pos/neg sets.
    datapipe = datapipe.sample_uniform_negative(graph, 1)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    datapipe = datapipe.to_dgl()
    dataloader = gb.DataLoader(
        datapipe,
    )
    # Recorded string form of each minibatch.  Whitespace inside the literals
    # is significant: the comparison below is an exact string match.
    expected = [
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]),
                                  tensor([2, 3, 3, 1])),
             output_nodes=None,
             node_features={'feat': tensor([[0.5160, 0.2486],
                                    [0.8672, 0.2276],
                                    [0.6172, 0.7865],
                                    [0.2109, 0.1089],
                                    [0.9634, 0.2294],
                                    [0.5503, 0.8223]])},
             negative_node_pairs=(tensor([0, 1, 1, 1]),
                                  tensor([4, 4, 1, 4])),
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
                     Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)],
          )""",
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]),
                                  tensor([0, 0, 1, 1])),
             output_nodes=None,
             node_features={'feat': tensor([[0.8672, 0.2276],
                                    [0.5503, 0.8223],
                                    [0.9634, 0.2294],
                                    [0.5160, 0.2486],
                                    [0.6172, 0.7865]])},
             negative_node_pairs=(tensor([0, 1, 1, 2]),
                                  tensor([1, 1, 3, 4])),
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2),
                     Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],
          )""",
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1]),
                                  tensor([0, 0])),
             output_nodes=None,
             node_features={'feat': tensor([[0.5160, 0.2486],
                                    [0.5503, 0.8223]])},
             negative_node_pairs=(tensor([0, 1]),
                                  tensor([0, 0])),
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
                     Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
          )""",
    ]
    num_batches = 0
    for step, data in enumerate(dataloader):
        assert step < len(expected), (
            f"unexpected extra minibatch #{step}:\n{data}"
        )
        assert expected[step] == str(data), (
            f"minibatch #{step} does not match the snapshot:\n{data}"
        )
        num_batches = step + 1
    # Guard against the loader silently yielding too few batches.
    assert num_batches == len(expected), (
        f"expected {len(expected)} minibatches, got {num_batches}"
    )


def test_integration_node_classification():
    """Run a GraphBolt sampling pipeline without negative sampling end to end.

    The same 6-node / 10-edge graph as in the link-prediction test is built
    from hard-coded CSC arrays and every edge is used as a seed node pair.
    The pairs go through: item sampling (batches of 4), two rounds of
    neighbor sampling (fanout 1, with replacement), feature fetching and
    conversion to DGL minibatches.  No negative sampling or seed-edge
    exclusion is applied here.  The string form of every produced minibatch
    is compared against a recorded snapshot.

    The RNG is seeded so that the sampled neighbors are reproducible; the
    snapshots below are only valid for this seed.
    """
    torch.manual_seed(926)

    # CSC representation of the graph: 6 nodes (len(indptr) - 1), 10 edges.
    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])

    # Turn the CSC structure into an (num_edges, 2) tensor of node pairs by
    # going through the COO view of the equivalent sparse matrix.
    matrix_a = dglsp.from_csc(indptr, indices)
    node_pairs = torch.t(torch.stack(matrix_a.coo()))
    # One 2-dim feature row per node.
    node_feature_data = torch.tensor(
        [
            [0.9634, 0.2294],
            [0.6172, 0.7865],
            [0.2109, 0.1089],
            [0.8672, 0.2276],
            [0.5503, 0.8223],
            [0.5160, 0.2486],
        ]
    )
    # One 3-dim feature row per edge.
    edge_feature_data = torch.tensor(
        [
            [0.5123, 0.1709, 0.6150],
            [0.1476, 0.1902, 0.1314],
            [0.2582, 0.5203, 0.6228],
            [0.3708, 0.7631, 0.2683],
            [0.2126, 0.7878, 0.7225],
            [0.7885, 0.3414, 0.5485],
            [0.4088, 0.8200, 0.1851],
            [0.0056, 0.9469, 0.4432],
            [0.8972, 0.7511, 0.3617],
            [0.5773, 0.2199, 0.3366],
        ]
    )

    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    graph = gb.from_fused_csc(indptr, indices)

    node_feature = gb.TorchBasedFeature(node_feature_data)
    edge_feature = gb.TorchBasedFeature(edge_feature_data)
    features = {
        ("node", None, "feat"): node_feature,
        ("edge", None, "feat"): edge_feature,
    }
    feature_store = gb.BasicFeatureStore(features)
    datapipe = gb.ItemSampler(item_set, batch_size=4)
    fanouts = torch.LongTensor([1])
    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
    datapipe = datapipe.fetch_feature(
        feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
    )
    datapipe = datapipe.to_dgl()
    dataloader = gb.DataLoader(
        datapipe,
    )
    # Recorded string form of each minibatch.  Whitespace inside the literals
    # is significant: the comparison below is an exact string match.
    expected = [
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]),
                                  tensor([2, 3, 3, 1])),
             output_nodes=None,
             node_features={'feat': tensor([[0.5160, 0.2486],
                                    [0.8672, 0.2276],
                                    [0.6172, 0.7865],
                                    [0.2109, 0.1089],
                                    [0.5503, 0.8223]])},
             negative_node_pairs=None,
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
                     Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
          )""",
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]),
                                  tensor([0, 0, 1, 1])),
             output_nodes=None,
             node_features={'feat': tensor([[0.8672, 0.2276],
                                    [0.5503, 0.8223],
                                    [0.9634, 0.2294]])},
             negative_node_pairs=None,
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
                     Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
          )""",
        """DGLMiniBatch(positive_node_pairs=(tensor([0, 1]),
                                  tensor([0, 0])),
             output_nodes=None,
             node_features={'feat': tensor([[0.5160, 0.2486],
                                    [0.5503, 0.8223],
                                    [0.9634, 0.2294]])},
             negative_node_pairs=None,
             labels=None,
             input_nodes=None,
             edge_features=[{},
                            {}],
             blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
                     Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
          )""",
    ]
    num_batches = 0
    for step, data in enumerate(dataloader):
        assert step < len(expected), (
            f"unexpected extra minibatch #{step}:\n{data}"
        )
        assert expected[step] == str(data), (
            f"minibatch #{step} does not match the snapshot:\n{data}"
        )
        num_batches = step + 1
    # Guard against the loader silently yielding too few batches.
    assert num_batches == len(expected), (
        f"expected {len(expected)} minibatches, got {num_batches}"
    )