# graph_cases.py
1
from collections import defaultdict
2

3
import backend as F
4
5
import dgl
import networkx as nx
6
import numpy as np
7
import scipy.sparse as ssp
8
9
10

# Maps each label to the list of case-builder functions registered under it.
# Filled in by the ``register_case`` decorator and read by ``get_cases``.
case_registry = defaultdict(list)

11

12
13
14
15
def register_case(labels):
    """Return a decorator that registers a case function under ``labels``.

    Parameters
    ----------
    labels : iterable of str
        Every label the decorated function should be retrievable by.

    Returns
    -------
    callable
        A decorator that appends the function to ``case_registry[label]``
        for each label, stores ``labels`` on the function as ``__labels__``
        and returns the function itself unchanged.
    """

    def wrapper(fn):
        for lbl in labels:
            case_registry[lbl].append(fn)
        # Keep the labels on the function so that callers can filter on them
        # later (``get_cases`` reads ``__labels__`` for its exclusion check).
        fn.__labels__ = labels
        return fn

    return wrapper

21

22
23
def get_cases(labels=None, exclude=()):
    """Get all graph instances of the given labels.

    Parameters
    ----------
    labels : iterable of str, optional
        Labels to select. ``None`` (the default) selects every label that
        currently has registered cases.
    exclude : iterable of str, optional
        A case is skipped when any of its own ``__labels__`` appears here.

    Returns
    -------
    list
        The result of calling each selected case function exactly once, even
        when the function is registered under several requested labels.
        Results follow the order in which the functions are first met while
        walking ``labels``, so repeated calls yield the same ordering.
    """
    if labels is None:
        # No filter requested: walk every registered label.
        labels = list(case_registry)

    # A dict de-duplicates like a set but keeps insertion order, which makes
    # the returned list deterministic from run to run.
    selected = {}
    for lbl in labels:
        # ``.get`` avoids inserting an empty entry for an unknown label.
        for case in case_registry.get(lbl, ()):
            if not any(tag in exclude for tag in case.__labels__):
                selected[case] = None
    return [fn() for fn in selected]

34
35

@register_case(["bipartite", "zero-degree"])
def bipartite1():
    """Build a single-relation ``("_U", "_E", "_V")`` heterograph with 6 edges.

    Source id 1 and destination id 2 never appear in the edge lists.

    NOTE(review): the name ``bipartite1`` is defined a second time further
    down in this file, so the module-level name ends up bound to that later
    function; this one stays reachable through the case registry only.
    """
    src = [0, 0, 0, 2, 2, 3]
    dst = [0, 1, 4, 1, 4, 3]
    return dgl.heterograph({("_U", "_E", "_V"): (src, dst)})

41

42
@register_case(["bipartite"])
def bipartite_full():
    """Build a ``("_U", "_E", "_V")`` heterograph with all 2 x 4 edges.

    Each of the source ids 0 and 1 is connected to every destination id 0-3.
    """
    src = [0, 0, 0, 0, 1, 1, 1, 1]
    dst = [0, 1, 2, 3, 0, 1, 2, 3]
    return dgl.heterograph({("_U", "_E", "_V"): (src, dst)})

53

54
@register_case(["homo"])
def graph0():
    """Build a 17-edge homogeneous graph whose node ids range over 0-9."""
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    return dgl.graph((src, dst))
62

63
64

@register_case(["homo", "zero-degree", "homo-zero-degree"])
def bipartite1():
    """Build a 6-edge homogeneous graph in which some node ids have no edges.

    Node id 1 never appears as a source and node id 2 never appears as a
    destination.

    NOTE(review): despite its name this builds a homogeneous ``dgl.graph``,
    and it reuses the name of the bipartite case defined earlier in this
    file, rebinding the module-level name to this function.
    """
    src = [0, 0, 0, 2, 2, 3]
    dst = [0, 1, 4, 1, 4, 3]
    return dgl.graph((src, dst))
67

68
69

@register_case(["homo", "has_feature"])
def graph1():
    """Build a 17-edge homogeneous CPU graph with random node/edge features.

    Node data ``"h"`` has 2 columns per node and edge data ``"w"`` has
    3 columns per edge; both are drawn with ``F.randn`` and copied to CPU.

    NOTE(review): the name ``graph1`` is defined again right after this
    function, so the module-level name ends up bound to that later function;
    this one stays reachable through the case registry only.
    """
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    g = dgl.graph((src, dst), device=F.cpu())
    g.ndata["h"] = F.copy_to(F.randn((g.num_nodes(), 2)), F.cpu())
    g.edata["w"] = F.copy_to(F.randn((g.num_edges(), 3)), F.cpu())
    return g

82
83

@register_case(["homo", "has_scalar_e_feature"])
def graph1():
    """Build a 17-edge homogeneous CPU graph with a scalar edge feature.

    Node data ``"h"`` has 2 random columns per node. Edge data
    ``"scalar_w"`` holds one value per edge, the absolute value of an
    ``F.randn`` draw.

    NOTE(review): this reuses the name of the ``has_feature`` case defined
    just above it in this file, rebinding the module-level name ``graph1``.
    """
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    g = dgl.graph((src, dst), device=F.cpu())
    g.ndata["h"] = F.copy_to(F.randn((g.num_nodes(), 2)), F.cpu())
    g.edata["scalar_w"] = F.copy_to(F.abs(F.randn((g.num_edges(),))), F.cpu())
    return g

96
97

@register_case(["homo", "row_sorted"])
def graph2():
    """Build the 17-edge homogeneous graph, declaring it ``row_sorted``.

    The source list is in non-decreasing order, matching the
    ``row_sorted=True`` flag handed to ``dgl.graph``.
    """
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    return dgl.graph((src, dst), row_sorted=True)

107

108
@register_case(["homo", "row_sorted", "col_sorted"])
def graph3():
    """Build a 17-edge homogeneous graph declared row- and column-sorted.

    Sources are in non-decreasing order and, within each source, the
    destinations are ascending, matching the two flags passed to
    ``dgl.graph``.
    """
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [1, 4, 5, 2, 4, 7, 8, 9, 1, 4, 6, 0, 0, 1, 2, 3, 5]
    return dgl.graph((src, dst), row_sorted=True, col_sorted=True)

119

120
@register_case(["hetero", "has_feature"])
def heterograph0():
    """Build a two-relation CPU heterograph with random ``"h"`` features.

    Relations: ``user -plays-> game`` (4 edges) and
    ``developer -develops-> game`` (2 edges). Every node type and edge type
    receives an ``"h"`` feature drawn with ``F.randn`` and copied to CPU;
    the column counts are user 3, game 2, developer 3, plays 1, develops 5.
    """
    g = dgl.heterograph(
        {
            ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        device=F.cpu(),
    )

    # Features are generated in a fixed order (node types first, then edge
    # types) so the sequence of random draws stays the same on every call.
    for ntype, width in (("user", 3), ("game", 2), ("developer", 3)):
        g.nodes[ntype].data["h"] = F.copy_to(
            F.randn((g.num_nodes(ntype), width)), F.cpu()
        )
    for etype, width in (("plays", 1), ("develops", 5)):
        g.edges[etype].data["h"] = F.copy_to(
            F.randn((g.num_edges(etype), width)), F.cpu()
        )
    return g


147
@register_case(["batched", "homo"])
def batched_graph0():
    """Batch three small homogeneous graphs after adding self-loops to each."""
    edge_lists = (
        ([0, 1, 2], [1, 2, 3]),
        ([1, 1], [2, 0]),
        ([0], [1]),
    )
    parts = [dgl.add_self_loop(dgl.graph(edges)) for edges in edge_lists]
    return dgl.batch(parts)

154
155

@register_case(["block", "bipartite", "block-bipartite"])
def block_graph0():
    """Convert a 3-edge, 100-node homogeneous CPU graph with ``dgl.to_block``."""
    src = [2, 3, 4]
    dst = [5, 6, 7]
    g = dgl.graph((src, dst), num_nodes=100).to(F.cpu())
    return dgl.to_block(g)

161
162

@register_case(["block"])
def block_graph1():
    """Convert a three-relation CPU heterograph with ``dgl.to_block``.

    The relations are ``user -plays-> game``, ``user -likes-> game`` and
    ``store -sells-> game``, each with 3 edges.
    """
    relations = {
        ("user", "plays", "game"): ([0, 1, 2], [1, 1, 0]),
        ("user", "likes", "game"): ([1, 2, 3], [0, 0, 2]),
        ("store", "sells", "game"): ([0, 1, 1], [0, 1, 2]),
    }
    g = dgl.heterograph(relations, device=F.cpu())
    return dgl.to_block(g)
173

174
175

@register_case(["clique"])
def clique():
    """Build a homogeneous graph with an edge for every ordered pair of 0-2.

    All nine (source, destination) pairs are listed, including the three
    pairs where source and destination are the same node.
    """
    src = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    dst = [0, 1, 2, 0, 1, 2, 0, 1, 2]
    return dgl.graph((src, dst))

180

181
182
183
def random_dglgraph(size):
    """Wrap a random ``nx.erdos_renyi_graph(size, 0.3)`` in ``dgl.DGLGraph``.

    NOTE(review): whether ``dgl.DGLGraph`` still accepts a networkx graph
    depends on the installed DGL version - confirm before relying on it.
    """
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.DGLGraph(nx_graph)

184

185
def random_graph(size):
    """Convert a random ``nx.erdos_renyi_graph(size, 0.3)`` to a DGL graph."""
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.from_networkx(nx_graph)
187

188

189
def random_bipartite(size_src, size_dst):
    """Build a bipartite graph from a random sparse matrix.

    The matrix comes from ``ssp.random(size_src, size_dst, 0.1)`` and is
    handed to ``dgl.bipartite_from_scipy``.
    """
    sparse_adj = ssp.random(size_src, size_dst, 0.1)
    # NOTE(review): vtype is "V" while utype/etype carry a leading
    # underscore ("_U", "_E") - confirm that this difference is intended.
    return dgl.bipartite_from_scipy(
        sparse_adj,
        utype="_U",
        etype="_E",
        vtype="V",
    )

197
198

def random_block(size):
    """Build a block from a random ``nx.erdos_renyi_graph(size, 0.1)``.

    The second argument given to ``dgl.to_block`` is the array of unique
    destination node ids found in the graph's edges.
    """
    g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))
    dst_ids = F.zerocopy_to_numpy(g.edges()[1])
    return dgl.to_block(g, np.unique(dst_ids))
201

202
203

@register_case(["two_hetero_batch"])
def two_hetero_batch():
    """Return a list of two heterographs that share the same relations.

    Both graphs have ``user -follows-> user``, ``user -follows-> developer``
    and ``user -plays-> game``. They differ only in the ``plays`` edges:
    four in the first graph, three in the second. The graphs are returned
    as a plain list; no batching is done here.
    """
    follows_edges = ([0, 1], [1, 2])
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): follows_edges,
            ("user", "follows", "developer"): follows_edges,
            ("user", "plays", "game"): ([0, 1, 2, 3], [0, 0, 1, 1]),
        }
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): follows_edges,
            ("user", "follows", "developer"): follows_edges,
            ("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
        }
    )
    return [g1, g2]

221
222

@register_case(["two_hetero_batch"])
def two_hetero_batch_with_isolated_ntypes():
    """Return two heterographs that also declare an edge-less node type.

    The relations match ``user -follows-> user``,
    ``user -follows-> developer`` and ``user -plays-> game``, but each graph
    is given an explicit ``num_nodes_dict`` that includes a ``"platform"``
    node type appearing in none of the relations (2 nodes in the first
    graph, 3 in the second). The graphs are returned as a plain list.
    """
    follows_edges = ([0, 1], [1, 2])
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): follows_edges,
            ("user", "follows", "developer"): follows_edges,
            ("user", "plays", "game"): ([0, 1, 2, 3], [0, 0, 1, 1]),
        },
        num_nodes_dict={"user": 4, "game": 2, "developer": 3, "platform": 2},
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): follows_edges,
            ("user", "follows", "developer"): follows_edges,
            ("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
        },
        num_nodes_dict={"user": 3, "game": 2, "developer": 3, "platform": 3},
    )
    return [g1, g2]