from collections import defaultdict

import backend as F
import dgl
import numpy as np
import networkx as nx
import scipy.sparse as ssp

# Label -> list of case-factory functions registered under that label by
# ``register_case``; read back by ``get_cases``.
case_registry = defaultdict(list)

def register_case(labels):
    """Return a decorator that registers a graph-case factory under ``labels``.

    The decorated function is appended to ``case_registry[label]`` for every
    label in ``labels`` and receives a ``__labels__`` attribute holding
    ``labels``; ``get_cases`` reads that attribute to apply its ``exclude``
    filter. The function itself is returned unchanged, so it can still be
    called directly.

    Parameters
    ----------
    labels : list of str
        Labels to register the decorated factory under.
    """
    def wrapper(fn):
        for lbl in labels:
            case_registry[lbl].append(fn)
        fn.__labels__ = labels
        return fn
    return wrapper

def get_cases(labels=None, exclude=None):
    """Get all graph instances of the given labels.

    Parameters
    ----------
    labels : iterable of str, optional
        Labels whose registered cases are collected. ``None`` (default) means
        every label currently in ``case_registry``.
    exclude : container of str, optional
        A case is skipped if any of its own labels is ``in`` this container.
        ``None`` (default) excludes nothing.

    Returns
    -------
    list
        One freshly built graph per matching factory. A factory registered
        under several of the requested labels is called only once, and the
        results follow the order in which labels are visited and factories
        were registered, so the list is identical on every run.
    """
    if exclude is None:
        exclude = []
    if labels is None:
        # get all the cases
        labels = list(case_registry)
    # An insertion-ordered dict dedupes factories like a set would, but the
    # output order no longer depends on function hashes/ids (which vary
    # between processes and made pytest parametrization unstable).
    cases = {}
    for lbl in labels:
        # .get() avoids inserting empty lists into the defaultdict for
        # labels nobody registered.
        for case in case_registry.get(lbl, ()):
            if not any(l in exclude for l in case.__labels__):
                cases[case] = None
    return [fn() for fn in cases]

@register_case(['bipartite', 'zero-degree'])
def bipartite1():
    """Bipartite graph ``_U -[_E]-> _V`` with 6 edges.

    Source node 1 has no outgoing edge and destination node 2 has no incoming
    edge, hence the ``'zero-degree'`` label.

    NOTE(review): the module-level name ``bipartite1`` is rebound further down
    by a homogeneous case with the same name; this factory is still reachable
    through ``case_registry`` because registration happens at decoration time.
    """
    return dgl.heterograph({('_U', '_E', '_V'): ([0, 0, 0, 2, 2, 3],
                                                 [0, 1, 4, 1, 4, 3])})

@register_case(['bipartite'])
def bipartite_full():
    """Complete bipartite graph ``_U -[_E]-> _V``: each of source nodes 0-1
    is connected to every destination node 0-3 (8 edges)."""
    return dgl.heterograph({('_U', '_E', '_V'): ([0, 0, 0, 0, 1, 1, 1, 1],
                                                 [0, 1, 2, 3, 0, 1, 2, 3])})

@register_case(['homo'])
def graph0():
    """Homogeneous directed graph on node ids 0-9 with 17 edges, no features."""
    return dgl.graph(([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
                      [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]))

@register_case(['homo', 'zero-degree', 'homo-zero-degree'])
def bipartite1():
    """Homogeneous graph on node ids 0-4 with 6 edges (including self-loop 3->3).

    Nodes 1 and 4 have no outgoing edges and node 2 has no incoming edges,
    hence the zero-degree labels.

    NOTE(review): despite its name this graph is homogeneous, and this
    definition rebinds the module-level name ``bipartite1`` (shadowing the
    bipartite case defined earlier). Both factories remain registered in
    ``case_registry``; the name is kept so direct references keep resolving
    to this function.
    """
    return dgl.graph(([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3]))

@register_case(['homo', 'has_feature'])
def graph1():
    """Same edge list as ``graph0``, created on ``F.cpu()`` with random features.

    Node feature ``'h'`` has shape ``(num_nodes, 2)`` and edge feature ``'w'``
    has shape ``(num_edges, 3)``; both are drawn with ``F.randn``.
    """
    ctx = F.cpu()
    g = dgl.graph(([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
                   [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]),
                  device=ctx)
    g.ndata['h'] = F.copy_to(F.randn((g.number_of_nodes(), 2)), ctx)
    g.edata['w'] = F.copy_to(F.randn((g.number_of_edges(), 3)), ctx)
    return g

@register_case(['hetero', 'has_feature'])
def heterograph0():
    """Heterograph on ``F.cpu()`` with random ``'h'`` features on every type.

    Relations: ``user -plays-> game`` (4 edges) and
    ``developer -develops-> game`` (2 edges). Feature widths per type:
    user 3, game 2, developer 3, plays 1, develops 5.
    """
    ctx = F.cpu()
    g = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1])}, device=ctx)
    # Dicts keep insertion order, so the random draws happen in a fixed
    # order: node types first, then edge types.
    node_dims = {'user': 3, 'game': 2, 'developer': 3}
    edge_dims = {'plays': 1, 'develops': 5}
    for ntype, dim in node_dims.items():
        g.nodes[ntype].data['h'] = F.copy_to(
            F.randn((g.number_of_nodes(ntype), dim)), ctx)
    for etype, dim in edge_dims.items():
        g.edges[etype].data['h'] = F.copy_to(
            F.randn((g.number_of_edges(etype), dim)), ctx)
    return g


@register_case(['batched', 'homo'])
def batched_graph0():
    """Batch of three small homogeneous graphs, each with self-loops added.

    Component edge lists before ``dgl.add_self_loop``: a path 0->1->2->3,
    the edges 1->2 and 1->0, and the single edge 0->1.
    """
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3])))
    g2 = dgl.add_self_loop(dgl.graph(([1, 1], [2, 0])))
    g3 = dgl.add_self_loop(dgl.graph(([0], [1])))
    return dgl.batch([g1, g2, g3])

# 'block-bipartite' is the intended label; the misspelled 'block-biparitite'
# is kept as well so existing lookups by the old spelling still find this case.
@register_case(['block', 'bipartite', 'block-bipartite', 'block-biparitite'])
def block_graph0():
    """Block built from a 100-node homogeneous graph with edges 2->5, 3->6, 4->7.

    ``dgl.to_block`` is called without ``dst_nodes``, so destination nodes are
    chosen by DGL's default for that call (presumably the nodes with incoming
    edges here, 5-7 -- confirm against the DGL version in use).
    """
    g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
    g = g.to(F.cpu())
    return dgl.to_block(g)

@register_case(['block'])
def block_graph1():
    """Block built from a heterograph on ``F.cpu()`` with three relations into
    ``game``: ``user -plays->``, ``user -likes->`` and ``store -sells->``."""
    g = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 2], [1, 1, 0]),
        ('user', 'likes', 'game'): ([1, 2, 3], [0, 0, 2]),
        ('store', 'sells', 'game'): ([0, 1, 1], [0, 1, 2]),
    }, device=F.cpu())
    return dgl.to_block(g)

@register_case(['clique'])
def clique():
    """Complete directed graph on 3 nodes, self-loops included (all 9 ordered
    pairs ``(u, v)`` with ``u, v`` in 0-2)."""
    g = dgl.graph(([0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]))
    return g

def random_dglgraph(size):
    """Build a graph via the ``dgl.DGLGraph`` constructor from an unseeded
    Erdos-Renyi networkx graph with ``size`` nodes and edge probability 0.3.

    Unlike ``random_graph`` this goes through ``DGLGraph(...)`` rather than
    ``dgl.from_networkx``.
    """
    return dgl.DGLGraph(nx.erdos_renyi_graph(size, 0.3))

def random_graph(size):
    """Convert an unseeded Erdos-Renyi networkx graph (``size`` nodes, edge
    probability 0.3) to a DGL graph with ``dgl.from_networkx``."""
    return dgl.from_networkx(nx.erdos_renyi_graph(size, 0.3))

def random_bipartite(size_src, size_dst):
    """Random bipartite graph from a ``size_src x size_dst`` scipy sparse matrix
    of density 0.1 (unseeded), with types ``_U -[_E]-> V``.

    NOTE(review): the destination type is ``'V'`` here but ``'_V'`` in the
    registered bipartite cases; confirm whether that difference is intended.
    """
    return dgl.bipartite_from_scipy(ssp.random(size_src, size_dst, 0.1),
                                    utype='_U', etype='_E', vtype='V')

def random_block(size):
    """Block from an unseeded Erdos-Renyi graph (``size`` nodes, edge
    probability 0.1).

    The destination nodes are the unique destination ids of the graph's
    edges (``g.edges()[1]``), passed to ``dgl.to_block`` as a numpy array.
    """
    g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))
    dst_nodes = np.unique(F.zerocopy_to_numpy(g.edges()[1]))
    return dgl.to_block(g, dst_nodes)