graph_cases.py 5.49 KB
Newer Older
1
from collections import defaultdict
2
import backend as F
3
import dgl
4
import numpy as np
5
import networkx as nx
6
import numpy as np
7
import scipy.sparse as ssp
8
import backend as F
9
10
11
12
13
14
15

# Maps each label to the list of case-building functions registered under it.
case_registry = defaultdict(list)


def register_case(labels):
    """Return a decorator that registers a graph-building function.

    The decorated function is appended to ``case_registry`` once per label
    and receives a ``__labels__`` attribute listing every label it was
    registered under.  The function itself is returned unchanged.

    Parameters
    ----------
    labels : iterable of str, or str
        Labels to file the function under.  A bare string is treated as a
        single label rather than being split into characters, and one-shot
        iterables (e.g. generators) are materialised up front so they are
        not exhausted before ``__labels__`` is stored.
    """
    labels = [labels] if isinstance(labels, str) else list(labels)

    def wrapper(fn):
        for lbl in labels:
            case_registry[lbl].append(fn)
        # Give each function its own copy so that mutating one function's
        # label list cannot affect another function or the caller's list.
        fn.__labels__ = list(labels)
        return fn

    return wrapper

20
21
def get_cases(labels=None, exclude=None):
    """Get all graph instances of the given labels.

    Parameters
    ----------
    labels : iterable of str, str, or None
        Labels whose registered cases should be collected.  ``None`` selects
        every label currently present in ``case_registry``.  A bare string
        is treated as a single label.
    exclude : container of str, optional
        A case is skipped when any entry of its ``__labels__`` is found in
        this container.  ``None`` (the default) excludes nothing.

    Returns
    -------
    list
        The result of calling each selected case function once.  A function
        registered under several requested labels is called only once, and
        the results follow registration order within the requested labels,
        so the output order is stable from run to run.
    """
    if exclude is None:
        exclude = ()
    if labels is None:
        # Snapshot the keys: take every registered label.
        labels = list(case_registry)
    elif isinstance(labels, str):
        labels = [labels]

    # A dict is used as an insertion-ordered set of case functions.
    selected = {}
    for lbl in labels:
        # .get() avoids creating empty registry entries for unknown labels.
        for case in case_registry.get(lbl, ()):
            if any(l in exclude for l in case.__labels__):
                continue
            selected.setdefault(case, None)
    return [fn() for fn in selected]

32
@register_case(['bipartite', 'zero-degree'])
def bipartite1():
    """Return a small bipartite heterograph that leaves some nodes unconnected.

    The single relation ``('_U', '_E', '_V')`` carries six edges.  Source
    ID 1 and destination ID 2 fall inside the ID ranges spanned by the edge
    lists but appear in no edge, which is presumably why the case is also
    labelled ``'zero-degree'``.

    NOTE(review): a later definition in this file reuses the name
    ``bipartite1``.  Both functions stay registered, but the module-level
    name ends up bound to the later one.
    """
    return dgl.heterograph({('_U', '_E', '_V'): ([0, 0, 0, 2, 2, 3],
                                                 [0, 1, 4, 1, 4, 3])})
36

37
@register_case(['bipartite'])
def bipartite_full():
    """Return a fully connected bipartite heterograph.

    The single relation ``('_U', '_E', '_V')`` connects each of the source
    IDs 0 and 1 to each of the destination IDs 0 through 3, for eight edges
    in total.
    """
    return dgl.heterograph({('_U', '_E', '_V'): ([0, 0, 0, 0, 1, 1, 1, 1],
                                                 [0, 1, 2, 3, 0, 1, 2, 3])})
41

42
43
44
45
46
@register_case(['homo'])
def graph0():
    """Return a fixed homogeneous graph with 17 directed edges.

    The endpoints use node IDs 0 through 9; edge ``i`` runs from
    ``src_ids[i]`` to ``dst_ids[i]``.
    """
    src_ids = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst_ids = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    return dgl.graph((src_ids, dst_ids))

47
48
@register_case(['homo', 'zero-degree', 'homo-zero-degree'])
def bipartite1():
    """Return a small homogeneous graph in which a node has no incoming edge.

    Despite the function name, the graph is built with ``dgl.graph`` and is
    registered under ``'homo'``.  It has six edges over node IDs 0 through
    4; node 2 never appears as a destination.

    NOTE(review): this definition reuses the name of an earlier
    ``bipartite1`` in this file.  Both functions stay registered, but the
    module-level name refers to this one after import.
    """
    return dgl.graph(([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3]))
50

51
52
53
@register_case(['homo', 'has_feature'])
def graph1():
    """Return a fixed homogeneous graph carrying node and edge features.

    The graph has 17 directed edges over node IDs 0 through 9 and is created
    on ``F.cpu()``.  Node feature ``'h'`` has two columns and edge feature
    ``'w'`` has three; both are produced by ``F.randn`` and copied to
    ``F.cpu()``.

    NOTE(review): a later definition in this file reuses the name
    ``graph1``.  Both functions stay registered, but the module-level name
    ends up bound to the later one.
    """
    g = dgl.graph(([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
                   [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]), device=F.cpu())
    g.ndata['h'] = F.copy_to(F.randn((g.number_of_nodes(), 2)), F.cpu())
    g.edata['w'] = F.copy_to(F.randn((g.number_of_edges(), 3)), F.cpu())
    return g

59
60
61
62
63
64
65
66
@register_case(['homo', 'has_scalar_e_feature'])
def graph1():
    """Return a fixed homogeneous graph whose edge feature is one value per edge.

    The graph has 17 directed edges over node IDs 0 through 9 and is created
    on ``F.cpu()``.  Node feature ``'h'`` has two columns.  Edge feature
    ``'scalar_w'`` is one-dimensional and is passed through ``F.abs``.

    NOTE(review): this definition reuses the name of an earlier ``graph1``
    in this file.  Both functions stay registered, but the module-level name
    refers to this one after import.
    """
    cpu = F.cpu()
    src_ids = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst_ids = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    graph = dgl.graph((src_ids, dst_ids), device=cpu)

    # Node features are drawn before edge features, as in the original order.
    node_feat = F.randn((graph.number_of_nodes(), 2))
    graph.ndata['h'] = F.copy_to(node_feat, F.cpu())

    edge_weight = F.abs(F.randn((graph.number_of_edges(),)))
    graph.edata['scalar_w'] = F.copy_to(edge_weight, F.cpu())
    return graph

67
68
69
70
@register_case(['hetero', 'has_feature'])
def heterograph0():
    """Return a small heterograph with a feature ``'h'`` on every type.

    Relations: four ``user -plays-> game`` edges and two
    ``developer -develops-> game`` edges, created on ``F.cpu()``.

    Feature ``'h'`` column counts: ``user`` 3, ``game`` 2, ``developer`` 3,
    ``plays`` 1, ``develops`` 5.  Each is produced by ``F.randn`` and copied
    to ``F.cpu()``.
    """
    g = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1])}, device=F.cpu())
    g.nodes['user'].data['h'] = F.copy_to(F.randn((g.number_of_nodes('user'), 3)), F.cpu())
    g.nodes['game'].data['h'] = F.copy_to(F.randn((g.number_of_nodes('game'), 2)), F.cpu())
    g.nodes['developer'].data['h'] = F.copy_to(F.randn((g.number_of_nodes('developer'), 3)), F.cpu())
    g.edges['plays'].data['h'] = F.copy_to(F.randn((g.number_of_edges('plays'), 1)), F.cpu())
    g.edges['develops'].data['h'] = F.copy_to(F.randn((g.number_of_edges('develops'), 5)), F.cpu())
    return g


@register_case(['batched', 'homo'])
def batched_graph0():
    """Return a batch of three small homogeneous graphs.

    Each component graph is passed through ``dgl.add_self_loop`` before the
    three are combined with ``dgl.batch``.
    """
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3])))
    g2 = dgl.add_self_loop(dgl.graph(([1, 1], [2, 0])))
    g3 = dgl.add_self_loop(dgl.graph(([0], [1])))
    return dgl.batch([g1, g2, g3])

# 'block-biparitite' is a historical misspelling.  It is kept as an alias so
# that existing lookups keep working, alongside the correct 'block-bipartite'.
@register_case(['block', 'bipartite', 'block-bipartite', 'block-biparitite'])
def block_graph0():
    """Return a block built from a sparse 100-node homogeneous graph.

    The source graph has three edges (2->5, 3->6, 4->7) and is declared with
    ``num_nodes=100``.  It is moved to ``F.cpu()`` and then converted with
    ``dgl.to_block``.
    """
    g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
    g = g.to(F.cpu())
    return dgl.to_block(g)

93
@register_case(['block'])
def block_graph1():
    """Return a block built from a heterograph with three relations.

    The source graph is created on ``F.cpu()`` with the relations
    ``user -plays-> game``, ``user -likes-> game`` and
    ``store -sells-> game`` (three edges each) and is then converted with
    ``dgl.to_block``.
    """
    g = dgl.heterograph({
            ('user', 'plays', 'game') : ([0, 1, 2], [1, 1, 0]),
            ('user', 'likes', 'game') : ([1, 2, 3], [0, 0, 2]),
            ('store', 'sells', 'game') : ([0, 1, 1], [0, 1, 2]),
        }, device=F.cpu())
    return dgl.to_block(g)
101

102
103
104
105
106
@register_case(['clique'])
def clique():
    """Return a three-node graph with an edge for every ordered node pair.

    Every (source, destination) combination over node IDs 0, 1, 2 is
    present, including the pairs where both endpoints are the same node.
    """
    node_count = 3
    # src == [0, 0, 0, 1, 1, 1, 2, 2, 2]
    src = [u for u in range(node_count) for _ in range(node_count)]
    # dst == [0, 1, 2, 0, 1, 2, 0, 1, 2]
    dst = [v for _ in range(node_count) for v in range(node_count)]
    return dgl.graph((src, dst))

107
108
109
110
def random_dglgraph(size):
    """Build a ``dgl.DGLGraph`` from a random Erdos-Renyi NetworkX graph.

    ``size`` is the node count handed to ``nx.erdos_renyi_graph``; the edge
    probability is fixed at 0.3.
    """
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.DGLGraph(nx_graph)

def random_graph(size):
    """Build a DGL graph from a random Erdos-Renyi NetworkX graph.

    ``size`` is the node count handed to ``nx.erdos_renyi_graph``; the edge
    probability is fixed at 0.3.  The result is converted with
    ``dgl.from_networkx``.
    """
    return dgl.from_networkx(nx.erdos_renyi_graph(size, 0.3))
112
113

def random_bipartite(size_src, size_dst):
    """Build a random bipartite graph from a sparse SciPy matrix.

    Parameters
    ----------
    size_src : int
        Row count of the random matrix passed to
        ``dgl.bipartite_from_scipy``.
    size_dst : int
        Column count of that matrix.

    The matrix comes from ``scipy.sparse.random`` with density 0.1.
    """
    adjacency = ssp.random(size_src, size_dst, 0.1)
    # The destination type is '_V', the same name the hand-written bipartite
    # cases in this file use for their ('_U', '_E', '_V') relation.
    return dgl.bipartite_from_scipy(adjacency,
                                    utype='_U', etype='_E', vtype='_V')
116
117

def random_block(size):
    """Build a block from a random Erdos-Renyi graph.

    ``size`` is the node count handed to ``nx.erdos_renyi_graph``; the edge
    probability is fixed at 0.1.  The unique destination endpoints of the
    graph's edges are passed to ``dgl.to_block`` as its second argument.
    """
    g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))
    return dgl.to_block(g, np.unique(F.zerocopy_to_numpy(g.edges()[1])))
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147

@register_case(['two_hetero_batch'])
def two_hetero_batch():
    """Return two heterographs that differ only in their ``plays`` edges.

    Both graphs share the same ``user -follows-> user`` and
    ``user -follows-> developer`` edges.  The first has four
    ``user -plays-> game`` edges and the second has three.
    """
    def _build(plays_src, plays_dst):
        # Fresh edge lists are created on every call.
        relations = {
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'follows', 'developer'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): (plays_src, plays_dst),
        }
        return dgl.heterograph(relations)

    first = _build([0, 1, 2, 3], [0, 0, 1, 1])
    second = _build([0, 1, 2], [0, 0, 1])
    return [first, second]

@register_case(['two_hetero_batch'])
def two_hetero_batch_with_isolated_ntypes():
    """Return two heterographs that include a node type with no edges.

    The relations match the plain two-graph case, but each graph is built
    with an explicit ``num_nodes_dict`` that also declares a ``'platform'``
    node type.  No relation references ``'platform'``.
    """
    variants = (
        (([0, 1, 2, 3], [0, 0, 1, 1]),
         {'user': 4, 'game': 2, 'developer': 3, 'platform': 2}),
        (([0, 1, 2], [0, 0, 1]),
         {'user': 3, 'game': 2, 'developer': 3, 'platform': 3}),
    )
    return [
        dgl.heterograph(
            {
                ('user', 'follows', 'user'): ([0, 1], [1, 2]),
                ('user', 'follows', 'developer'): ([0, 1], [1, 2]),
                ('user', 'plays', 'game'): plays_edges,
            },
            num_nodes_dict=node_counts,
        )
        for plays_edges, node_counts in variants
    ]