# graph_cases.py: registry of sample graph constructors ("cases") selectable by label.
1
from collections import defaultdict
2
import backend as F
3
import dgl
4
import numpy as np
5
import networkx as nx
6
import numpy as np
7
import scipy.sparse as ssp
8
import backend as F
9
10
11
12
13
14
15

# Maps each label to the list of case functions filed under it.
# Filled in by ``register_case`` and read by ``get_cases``.
case_registry = defaultdict(list)

def register_case(labels):
    """Return a decorator that files a graph-case function under ``labels``.

    The decorated function is appended to ``case_registry[lbl]`` for every
    label. It also gets a ``__labels__`` attribute holding the labels, and is
    returned unchanged so it can still be called directly.

    Parameters
    ----------
    labels : str or iterable of str
        Labels the case is registered under. ``get_cases`` selects and
        excludes cases by these. A bare string is treated as a single label
        rather than being split into characters.
    """
    if isinstance(labels, str):
        labels = [labels]
    # Materialize once: the labels are both iterated here and stored on the
    # function, so a one-shot iterable (e.g. a generator) must not be
    # exhausted by the registration loop.
    labels = list(labels)

    def wrapper(fn):
        for lbl in labels:
            case_registry[lbl].append(fn)
        fn.__labels__ = labels
        return fn
    return wrapper

20
21
def get_cases(labels=None, exclude=()):
    """Get all graph instances of the given labels.

    Parameters
    ----------
    labels : str or iterable of str, optional
        Labels to select. ``None`` selects every registered label. A bare
        string is treated as a single label.
    exclude : str or iterable of str, optional
        A case is skipped if any of its labels appears here. A bare string
        is treated as a single label.

    Returns
    -------
    list
        One freshly built graph per matching case function. A case filed
        under several selected labels is built only once. Order is
        deterministic: labels are walked in the order given (registry order
        when ``labels`` is None), and cases within a label follow
        registration order. Deterministic order keeps test parametrization
        stable across runs and workers.
    """
    if isinstance(labels, str):
        labels = [labels]
    excluded = {exclude} if isinstance(exclude, str) else set(exclude)
    if labels is None:
        # get all the cases
        labels = list(case_registry.keys())

    seen = set()
    cases = []
    for lbl in labels:
        # ``.get`` rather than ``[]`` so an unknown label does not insert an
        # empty entry into the defaultdict registry as a side effect.
        for case in case_registry.get(lbl, ()):
            if case in seen or excluded.intersection(case.__labels__):
                continue
            seen.add(case)
            cases.append(case)
    return [fn() for fn in cases]

32
@register_case(['dglgraph', 'path'])
def dglgraph_path():
    """Return a ``dgl.DGLGraph`` built from networkx's 5-node path graph."""
    return dgl.DGLGraph(nx.path_graph(5))

36
@register_case(['bipartite'])
def bipartite1():
    """Return a bipartite graph from an explicit list of six (src, dst) pairs."""
    edges = [(0, 0), (0, 1), (0, 4), (2, 1), (2, 4), (3, 3)]
    return dgl.bipartite(edges)

40
@register_case(['bipartite'])
def bipartite_full():
    """Return a bipartite graph connecting each of src ids 0-1 to each of dst ids 0-3.

    The pair list is generated in the same order as the literal
    ``[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]``.
    """
    edges = [(u, v) for u in range(2) for v in range(4)]
    return dgl.bipartite(edges)
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
@register_case(['homo'])
def graph0():
    """Return a featureless homogeneous graph with 17 edges over node ids 0-9."""
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    return dgl.graph((src, dst))

@register_case(['homo', 'has_feature'])
def graph1():
    """Return graph0's topology with random CPU features attached.

    Node feature 'h' has 2 columns; edge feature 'w' has 3 columns. The node
    features are drawn first, then the edge features.
    """
    src = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9]
    dst = [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]
    graph = dgl.graph((src, dst))
    n_nodes = graph.number_of_nodes()
    n_edges = graph.number_of_edges()
    graph.ndata['h'] = F.copy_to(F.randn((n_nodes, 2)), F.cpu())
    graph.edata['w'] = F.copy_to(F.randn((n_edges, 3)), F.cpu())
    return graph

@register_case(['hetero', 'has_feature'])
def heterograph0():
    """Return a user/developer -> game heterograph with random CPU features.

    Every node and edge type gets a feature named 'h'. The column counts are
    user 3, game 2, developer 3, plays 1 and develops 5.
    """
    graph = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1])})
    # The tuples are listed in the original draw order, so the random
    # stream is consumed identically.
    for ntype, width in (('user', 3), ('game', 2), ('developer', 3)):
        feat = F.randn((graph.number_of_nodes(ntype), width))
        graph.nodes[ntype].data['h'] = F.copy_to(feat, F.cpu())
    for etype, width in (('plays', 1), ('develops', 5)):
        feat = F.randn((graph.number_of_edges(etype), width))
        graph.edges[etype].data['h'] = F.copy_to(feat, F.cpu())
    return graph


@register_case(['batched', 'homo'])
def batched_graph0():
    """Return a ``dgl.batch`` of three small homogeneous graphs."""
    edge_lists = [
        ([0, 1, 2], [1, 2, 3]),
        ([1, 1], [2, 0]),
        ([0], [1]),
    ]
    return dgl.batch([dgl.graph(edges) for edges in edge_lists])

@register_case(['block', 'bipartite', 'block-bipartite', 'block-biparitite'])
def block_graph0():
    """Return ``dgl.to_block`` of a 100-node graph with edges 2->5, 3->6, 4->7.

    The case is filed under the correctly spelled 'block-bipartite' label.
    The misspelled 'block-biparitite' label is kept so selectors that
    already use it still find this case.
    """
    g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
    return dgl.to_block(g)

82
@register_case(['block'])
def block_graph1():
    """Return ``dgl.to_block`` of a heterograph whose relations all end at 'game'.

    The sources are 'user' (plays, likes) and 'store' (sells).
    """
    g = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 2], [1, 1, 0]),
        ('user', 'likes', 'game'): ([1, 2, 3], [0, 0, 2]),
        ('store', 'sells', 'game'): ([0, 1, 1], [0, 1, 2]),
    })
    return dgl.to_block(g)
90

91
92
93
94
95
96
97
98
def random_dglgraph(size):
    """Return a ``dgl.DGLGraph`` from a networkx Erdos-Renyi graph (p=0.3) on ``size`` nodes."""
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.DGLGraph(nx_graph)

def random_graph(size):
    """Return a ``dgl.graph`` from a networkx Erdos-Renyi graph (p=0.3) on ``size`` nodes."""
    nx_graph = nx.erdos_renyi_graph(size, 0.3)
    return dgl.graph(nx_graph)

def random_bipartite(size_src, size_dst):
    """Return a bipartite graph from a random ``size_src`` x ``size_dst`` sparse matrix with density 0.1."""
    adjacency = ssp.random(size_src, size_dst, 0.1)
    return dgl.bipartite(adjacency)
99
100
101
102

def random_block(size):
    """Return ``dgl.to_block`` of a random Erdos-Renyi graph (p=0.1) on ``size`` nodes.

    The destination nodes passed to ``to_block`` are the unique ids that
    appear as an edge destination.
    """
    graph = dgl.graph(nx.erdos_renyi_graph(size, 0.1))
    _, dst = graph.edges()
    dst_nodes = np.unique(F.zerocopy_to_numpy(dst))
    return dgl.to_block(graph, dst_nodes)