graph_cases.py 6.39 KB
Newer Older
1
from collections import defaultdict
2

3
import backend as F
4
5
import dgl
import networkx as nx
6
import numpy as np
7
import scipy.sparse as ssp
8
9
10

# Maps each label (e.g. "homo", "bipartite") to the case factory functions
# registered under it, in registration order.
case_registry = defaultdict(list)


def register_case(labels):
    """Return a decorator that files a graph-case factory under ``labels``.

    Parameters
    ----------
    labels : list of str or str
        Labels the decorated factory is registered under. A single string is
        treated as one label rather than as a sequence of characters.

    Returns
    -------
    callable
        Decorator that appends the factory to ``case_registry`` for every
        label, stores the labels on ``fn.__labels__`` and returns ``fn``
        unchanged.
    """
    if isinstance(labels, str):
        labels = [labels]

    def wrapper(fn):
        for lbl in labels:
            case_registry[lbl].append(fn)
        # get_cases() reads this attribute to apply its ``exclude`` filter.
        fn.__labels__ = labels
        return fn

    return wrapper


def get_cases(labels=None, exclude=()):
    """Get all graph instances of the given labels.

    Parameters
    ----------
    labels : iterable of str, str, or None
        Labels whose registered factories are selected. ``None`` selects
        every registered label. A single string is treated as one label.
    exclude : iterable of str or str
        A factory is skipped if any of its own labels is in ``exclude``.

    Returns
    -------
    list
        The result of calling each selected factory once. A factory that is
        registered under several requested labels is called only once. Order
        follows label order, then registration order, so it is the same in
        every run and process.
    """
    if labels is None:
        # get all the cases
        labels = list(case_registry)
    elif isinstance(labels, str):
        labels = [labels]
    if isinstance(exclude, str):
        exclude = [exclude]
    excluded = set(exclude)

    # A dict (not a set) drops duplicates while keeping insertion order; a
    # set of functions iterates in id()-dependent, run-to-run varying order.
    selected = {}
    for lbl in labels:
        # .get() avoids inserting empty entries into the defaultdict for
        # labels nobody registered.
        for case in case_registry.get(lbl, ()):
            if excluded.isdisjoint(case.__labels__):
                selected[case] = None
    return [fn() for fn in selected]

34
35

@register_case(["bipartite", "zero-degree"])
def bipartite_zero_degree():
    """Bipartite ``_U -> _V`` graph in which some nodes have no edges.

    Source node 1 and destination node 2 appear in no edge, which gives this
    case its zero-degree nodes.
    """
    # Named distinctly from the homogeneous ``bipartite1`` case defined later,
    # which otherwise shadows this function at module level.
    return dgl.heterograph(
        {("_U", "_E", "_V"): ([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3])}
    )

41

42
@register_case(["bipartite"])
def bipartite_full():
    """Complete bipartite graph from 2 ``_U`` nodes to 4 ``_V`` nodes.

    Every source node (0, 1) is connected to every destination node (0-3),
    which gives 8 edges of type ``_E``.
    """
    return dgl.heterograph(
        {
            ("_U", "_E", "_V"): (
                [0, 0, 0, 0, 1, 1, 1, 1],
                [0, 1, 2, 3, 0, 1, 2, 3],
            )
        }
    )

53

54
@register_case(["homo"])
def graph0():
    """Homogeneous directed graph over node IDs 0-9 with 17 edges and no features."""
    return dgl.graph(
        (
            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],
        )
    )
62

63
64

@register_case(["homo", "zero-degree", "homo-zero-degree"])
def bipartite1():
    """Small homogeneous graph (node IDs 0-4) that contains zero-degree nodes.

    Node 2 receives no edges, and nodes 1 and 4 send none.
    """
    # NOTE: despite the name, this is a homogeneous ``dgl.graph``, not a
    # bipartite heterograph. The name is kept because it is the module-level
    # binding for ``bipartite1`` that external code may import.
    return dgl.graph(([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3]))
67

68
69

@register_case(["homo", "has_feature"])
def graph_with_feature():
    """Homogeneous CPU graph with random node and edge features.

    The structure is the same 17-edge graph over node IDs 0-9 used by
    ``graph0``. Two features are attached, both filled with ``F.randn``
    values and copied to CPU:

    - ``ndata["h"]`` with shape ``(num_nodes, 2)``;
    - ``edata["w"]`` with shape ``(num_edges, 3)``.
    """
    # Named distinctly from the ``graph1`` case defined later, which otherwise
    # shadows this function at module level.
    g = dgl.graph(
        (
            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],
        ),
        device=F.cpu(),
    )
    g.ndata["h"] = F.copy_to(F.randn((g.number_of_nodes(), 2)), F.cpu())
    g.edata["w"] = F.copy_to(F.randn((g.number_of_edges(), 3)), F.cpu())
    return g

82
83

@register_case(["homo", "has_scalar_e_feature"])
def graph1():
    """Homogeneous CPU graph with a node feature and a scalar edge weight.

    ``ndata["h"]`` has shape ``(num_nodes, 2)``. ``edata["scalar_w"]`` is a
    1-D tensor of length ``num_edges`` made non-negative with ``F.abs``.
    """
    g = dgl.graph(
        (
            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],
        ),
        device=F.cpu(),
    )
    g.ndata["h"] = F.copy_to(F.randn((g.number_of_nodes(), 2)), F.cpu())
    g.edata["scalar_w"] = F.copy_to(
        F.abs(F.randn((g.number_of_edges(),))), F.cpu()
    )
    return g

98
99

@register_case(["homo", "row_sorted"])
def graph2():
    """Homogeneous graph whose edges are listed in source-node order.

    The source list is non-decreasing, so ``row_sorted=True`` is a truthful
    hint. Destinations within each source are not sorted, so
    ``col_sorted`` is not passed.
    """
    return dgl.graph(
        (
            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],
        ),
        row_sorted=True,
    )

109

110
@register_case(["homo", "row_sorted", "col_sorted"])
def graph3():
    """Homogeneous graph with edges sorted by source, then by destination.

    It has the same edge set as ``graph2``, but destinations are ascending
    within each source. Both the ``row_sorted`` and ``col_sorted`` hints
    therefore hold.
    """
    return dgl.graph(
        (
            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
            [1, 4, 5, 2, 4, 7, 8, 9, 1, 4, 6, 0, 0, 1, 2, 3, 5],
        ),
        row_sorted=True,
        col_sorted=True,
    )

121

122
@register_case(["hetero", "has_feature"])
def heterograph0():
    """Heterogeneous CPU graph with per-type random features.

    Relations are ``user -plays-> game`` (4 edges) and
    ``developer -develops-> game`` (2 edges). Every node and edge type gets a
    feature named ``"h"``, with a different width per type:

    - user 3, game 2, developer 3;
    - plays 1, develops 5.
    """
    g = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        device=F.cpu(),
    )
    g.nodes["user"].data["h"] = F.copy_to(
        F.randn((g.number_of_nodes("user"), 3)), F.cpu()
    )
    g.nodes["game"].data["h"] = F.copy_to(
        F.randn((g.number_of_nodes("game"), 2)), F.cpu()
    )
    g.nodes["developer"].data["h"] = F.copy_to(
        F.randn((g.number_of_nodes("developer"), 3)), F.cpu()
    )
    g.edges["plays"].data["h"] = F.copy_to(
        F.randn((g.number_of_edges("plays"), 1)), F.cpu()
    )
    g.edges["develops"].data["h"] = F.copy_to(
        F.randn((g.number_of_edges("develops"), 5)), F.cpu()
    )
    return g


149
@register_case(["batched", "homo"])
def batched_graph0():
    """Batch of three small homogeneous graphs, each with self-loops added.

    The components span node IDs 0-3, 0-2 and 0-1 respectively.
    """
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3])))
    g2 = dgl.add_self_loop(dgl.graph(([1, 1], [2, 0])))
    g3 = dgl.add_self_loop(dgl.graph(([0], [1])))
    return dgl.batch([g1, g2, g3])

156
157

@register_case(["block", "bipartite", "block-bipartite"])
def block_graph0():
    """Block converted from a sparse 100-node homogeneous CPU graph.

    Only three edges (2->5, 3->6, 4->7) exist among the 100 nodes.
    ``dgl.to_block`` is called without explicit destination nodes.
    """
    # Created directly on CPU via ``device=``, like the other cases, instead
    # of building first and then calling ``.to(F.cpu())``.
    g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100, device=F.cpu())
    return dgl.to_block(g)

163
164

@register_case(["block"])
def block_graph1():
    """Block converted from a heterogeneous CPU graph with three relations.

    The relations are ``user -plays-> game``, ``user -likes-> game`` and
    ``store -sells-> game``, with three edges each.
    """
    g = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 2], [1, 1, 0]),
            ("user", "likes", "game"): ([1, 2, 3], [0, 0, 2]),
            ("store", "sells", "game"): ([0, 1, 1], [0, 1, 2]),
        },
        device=F.cpu(),
    )
    return dgl.to_block(g)
175

176
177

@register_case(["clique"])
def clique():
    """Homogeneous 3-node graph containing all 9 ordered pairs, self-loops included."""
    g = dgl.graph(([0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]))
    return g

182

183
184
185
def random_dglgraph(size):
    """Build a random graph through the legacy ``dgl.DGLGraph`` constructor.

    A NetworkX Erdos-Renyi graph with ``size`` nodes and edge probability 0.3
    is passed directly to ``dgl.DGLGraph``, not to ``dgl.from_networkx``.
    """
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.DGLGraph(nx_graph)

186

187
def random_graph(size):
    """Return a random homogeneous DGL graph with ``size`` nodes.

    The structure is a NetworkX Erdos-Renyi graph with edge probability 0.3,
    converted with ``dgl.from_networkx``. No seed is passed, so each call
    yields a fresh random structure.
    """
    return dgl.from_networkx(nx.erdos_renyi_graph(size, 0.3))
189

190

191
def random_bipartite(size_src, size_dst):
    """Return a random bipartite graph built from a sparse random matrix.

    ``ssp.random(size_src, size_dst, 0.1)`` supplies a density-0.1 nonzero
    pattern. ``dgl.bipartite_from_scipy`` turns each nonzero into an edge
    from source type ``_U`` to destination type ``V``.
    """
    # NOTE(review): the destination node type is "V", while every other
    # bipartite case in this module uses "_V". Kept unchanged because callers
    # may look up the ntype by name -- confirm whether "_V" was intended.
    return dgl.bipartite_from_scipy(
        ssp.random(size_src, size_dst, 0.1),
        utype="_U",
        etype="_E",
        vtype="V",
    )

199
200

def random_block(size):
    """Return a block made from a random ``size``-node homogeneous graph.

    The graph is a NetworkX Erdos-Renyi graph with edge probability 0.1. Its
    unique edge destinations, computed with NumPy, are passed to
    ``dgl.to_block`` as the destination nodes.
    """
    g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))
    return dgl.to_block(g, np.unique(F.zerocopy_to_numpy(g.edges()[1])))
203

204
205

@register_case(["two_hetero_batch"])
def two_hetero_batch():
    """Return two heterographs with the same relations but different sizes.

    Both graphs share the canonical edge types:

    - ``(user, follows, user)``;
    - ``(user, follows, developer)``;
    - ``(user, plays, game)``.

    They differ in the number of ``plays`` edges (4 vs 3) and in how many
    user IDs appear.
    """
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "follows", "developer"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 3], [0, 0, 1, 1]),
        }
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "follows", "developer"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
        }
    )
    return [g1, g2]

223
224

@register_case(["two_hetero_batch"])
def two_hetero_batch_with_isolated_ntypes():
    """Return two heterographs that include a node type with no edges.

    The relations match ``two_hetero_batch``. ``num_nodes_dict`` also
    declares a ``platform`` node type that no relation touches, with 2 nodes
    in the first graph and 3 in the second.
    """
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "follows", "developer"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 3], [0, 0, 1, 1]),
        },
        num_nodes_dict={"user": 4, "game": 2, "developer": 3, "platform": 2},
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "follows", "developer"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
        },
        num_nodes_dict={"user": 3, "game": 2, "developer": 3, "platform": 3},
    )
    return [g1, g2]