graph.py 3.74 KB
Newer Older
Gan Quan's avatar
Gan Quan committed
1
import networkx as nx
zzhang-cn's avatar
zzhang-cn committed
2
from networkx.classes.digraph import DiGraph
Gan Quan's avatar
Gan Quan committed
3

zzhang-cn's avatar
zzhang-cn committed
4
class dgl_Graph(DiGraph):
    '''Directed graph with per-node representations and message passing.

    Each node keeps its representation in its attribute dict, under the key
    ``'state'`` by default (see ``set_repr`` / ``get_repr``).

    Functions:
        - m_func: per edge (u, v). Called with u's representation; the
          result is stored on the edge as ``'msg'``. An edge may carry its
          own override in the edge attribute ``'m_func'``.
        - u_func: per node u. Called as ``u_func(u_state, msgs)`` where
          ``msgs`` is a list of incoming messages; the result becomes u's new
          representation. A node may carry its own override in the node
          attribute ``'u_func'``.
        - readout_func: called by ``readout`` with a list of node
          representations.

    The original notes describe the defaults as: m_func returns u['state'],
    u_func is RNN(m, u['state']) -- NOTE(review): not verifiable from this
    file, confirm against the Default*Module implementations.
    '''

    def __init__(self, *args, **kargs):
        """Build the graph and reset every node representation to ``None``.

        All positional and keyword arguments are forwarded unchanged to
        ``networkx.DiGraph.__init__``.
        """
        super(dgl_Graph, self).__init__(*args, **kargs)
        # NOTE(review): DefaultMessageModule, DefaultUpdateModule and
        # DefaultReadoutModule are not defined or imported in the visible
        # source, so construction raises NameError unless they are provided
        # elsewhere -- confirm where they are meant to come from.
        self.m_func = DefaultMessageModule()
        self.u_func = DefaultUpdateModule()
        self.readout_func = DefaultReadoutModule()
        self.init_reprs()

    @staticmethod
    def _is_all(selection):
        """Return True when *selection* is the ``'all'`` sentinel string."""
        # The isinstance guard keeps array-like selections from being
        # compared element-wise against the string.
        return isinstance(selection, str) and selection == 'all'

    def init_reprs(self, h_init=None):
        """Set the ``'state'`` representation of every node to *h_init*.

        The same *h_init* object is assigned to every node (it is not copied).
        """
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name='state'):
        """Store *h_u* as node *u*'s attribute *name*.

        Args:
            u: an existing node (asserted).
            h_u: the representation to store.
            name: attribute key, ``'state'`` by default.
        """
        assert u in self.nodes
        kwarg = {name: h_u}
        # add_node on an existing node only updates its attribute dict.
        self.add_node(u, **kwarg)

    def get_repr(self, u, name='state'):
        """Return node *u*'s attribute *name* (``'state'`` by default).

        Raises:
            AssertionError: if *u* is not a node of the graph.
            KeyError: if the node has no attribute *name*.
        """
        assert u in self.nodes
        return self.nodes[u][name]

    def _nodes_or_all(self, nodes='all'):
        """Resolve the ``'all'`` sentinel to every node; else return *nodes*."""
        return self.nodes() if self._is_all(nodes) else nodes

    def _edges_or_all(self, edges='all'):
        """Resolve the ``'all'`` sentinel to every edge; else return *edges*."""
        return self.edges() if self._is_all(edges) else edges

    def register_message_func(self, message_func, edges='all', batched=False):
        """Register a message function.

        Args:
            message_func: callable that receives the source node's
                representation and returns the message.
            edges: ``'all'`` to replace the graph-wide default ``self.m_func``;
                otherwise an iterable of ``(u, v)`` edges that should use
                *message_func* as a per-edge override. Edges that already
                carry an override keep it when ``'all'`` is used.
            batched: currently unused.
        """
        if self._is_all(edges):
            self.m_func = message_func
        else:
            # Only the requested edges get the override; sendto() prefers
            # the per-edge 'm_func' attribute over self.m_func.
            for e in edges:
                self.edges[e]['m_func'] = message_func

    def register_update_func(self, update_func, nodes='all', batched=False):
        """Register an update function.

        Args:
            update_func: callable ``f(node_state, msgs) -> new_state``.
            nodes: ``'all'`` to replace the graph-wide default ``self.u_func``;
                otherwise an iterable of nodes that should use *update_func*
                as a per-node override. Nodes that already carry an override
                keep it when ``'all'`` is used.
            batched: currently unused.
        """
        if self._is_all(nodes):
            self.u_func = update_func
        else:
            # ``self.nodes[n]`` is the node attribute dict; the old
            # ``self.node`` accessor no longer exists in NetworkX >= 2.4.
            for n in nodes:
                self.nodes[n]['u_func'] = update_func

    def register_readout_func(self, readout_func):
        """Replace the function used by ``readout``."""
        self.readout_func = readout_func

    def readout(self, nodes='all', **kwargs):
        """Apply the readout function to the representations of *nodes*.

        Args:
            nodes: ``'all'`` or an iterable of nodes.
            **kwargs: forwarded to the readout function.

        Returns:
            ``readout_func(list_of_states, **kwargs)``.
        """
        nodes_state = [self.get_repr(n) for n in self._nodes_or_all(nodes)]
        return self.readout_func(nodes_state, **kwargs)

    def sendto(self, u, v):
        """Compute message on edge u->v and store it as the edge's ``'msg'``.

        Args:
            u: source node
            v: destination node
        """
        f_msg = self.edges[(u, v)].get('m_func', self.m_func)
        m = f_msg(self.get_repr(u))
        self.edges[(u, v)]['msg'] = m

    def sendto_ebunch(self, ebunch):
        """Compute the message on every edge u->v in *ebunch*.

        Args:
            ebunch: an iterable of ``(u, v)`` edges
        """
        for u, v in ebunch:
            self.sendto(u, v)

    def recvfrom(self, u, nodes):
        """Update u from messages already stored on edges (v, u).

        Args:
            u: node to be updated
            nodes: nodes with pre-computed messages to u; the messages are
                passed to the update function in this order.
        """
        m = [self.edges[(v, u)]['msg'] for v in nodes]
        f_update = self.nodes[u].get('u_func', self.u_func)
        x_new = f_update(self.get_repr(u), m)
        self.set_repr(u, x_new)

    def update_by_edge(self, e):
        """Send along edge ``e = (u, v)`` and update v from that one message."""
        u, v = e
        self.sendto(u, v)
        self.recvfrom(v, [u])

    def update_to(self, u):
        """Pull messages from 1-step away predecessors of u, then update u."""
        assert u in self.nodes
        preds = list(self.pred[u])
        self.sendto_ebunch((v, u) for v in preds)
        self.recvfrom(u, preds)

    def update_from(self, u):
        """Update u's 1-step away successors.

        Each successor pulls from *all* of its predecessors (see
        ``update_to``), not only from u.
        """
        assert u in self.nodes
        for v in self.succ[u]:
            self.update_to(v)

    def update_all_step(self):
        """Run one synchronous message-passing step over the whole graph.

        All messages are computed before any node is updated, so every
        message reflects the pre-step representations. Nodes without
        predecessors are still updated, with an empty message list.
        """
        self.sendto_ebunch(self.edges)
        for u in self.nodes:
            self.recvfrom(u, list(self.pred[u]))

    def draw(self):
        """Draw the graph with a Graphviz ``dot`` layout (needs pygraphviz)."""
        from networkx.drawing.nx_agraph import graphviz_layout

        pos = graphviz_layout(self, prog='dot')
        nx.draw(self, pos, with_labels=True)

    def print_all(self):
        """Print every node with its attribute dict, then a blank line."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()