mx.py 3.35 KB
Newer Older
zzhang-cn's avatar
zzhang-cn committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import networkx as nx
from networkx.classes.graph import Graph

# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training

class mx_Graph(Graph):
    """networkx ``Graph`` whose nodes hold a representation and pass messages.

    Each node keeps its representation in a node attribute (``'h'`` unless
    another attribute name is given).  Three replaceable callables drive the
    computation:

    * ``msg_func(u=...)`` gathers a message for node ``u`` from its neighbors,
    * ``update_func(u, m)`` folds message ``m`` into node ``u``,
    * ``readout_func()`` reduces the whole graph to a single value.

    A representation of ``None`` means "not set yet"; the default functions
    skip such nodes instead of failing on them.
    """

    def __init__(self, *args, **kargs):
        """Build the graph and install the default functions.

        All arguments are forwarded unchanged to ``Graph.__init__``.  After
        construction the ``'h'`` attribute of every node is set to ``None``,
        overwriting any ``'h'`` value copied from the input data.
        """
        super(mx_Graph, self).__init__(*args, **kargs)
        self.set_msg_func()
        self.set_update_func()
        self.set_readout_func()
        self.init_reprs()

    def init_reprs(self, h_init=None):
        """Set the ``'h'`` representation of every node to ``h_init``.

        The same ``h_init`` object is stored on every node (it is not copied).
        """
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name=None):
        """Store ``h_u`` on node ``u``.

        Args:
            u: A node that is already in the graph.
            h_u: The value to store.
            name: Attribute name to store under (a string); ``None`` selects
                the default attribute ``'h'``.
        """
        assert u in self.nodes
        key = 'h' if name is None else name
        # Expand the key explicitly: ``add_node(u, name=h_u)`` would write an
        # attribute literally called "name" instead of the requested one,
        # which get_repr(u, name) could then never find.
        self.add_node(u, **{key: h_u})

    def get_repr(self, u, name=None):
        """Return the value stored on node ``u``.

        Args:
            u: A node that is already in the graph.
            name: Attribute name to read; ``None`` selects ``'h'``.

        Raises:
            KeyError: If the node has no attribute with that name.
        """
        assert u in self.nodes
        key = 'h' if name is None else name
        return self.nodes[u][key]

    def set_msg_func(self, func=None, u=None):
        """Install the function that gathers messages from neighbors.

        Args:
            func: Callable invoked as ``func(u=node)`` that returns the
                message for ``node``.  ``None`` installs the default, which
                sums the ``'h'`` representations of the node's neighbors,
                skipping those that are ``None`` (0 if nothing is summed).
            u: Currently unused; see the per-node TODO below.
        """
        def _default_msg_func(u):
            assert u in self.nodes
            msg_gathered = 0
            for v in self.adj[u]:
                x = self.get_repr(v)
                if x is not None:
                    msg_gathered += x
            return msg_gathered

        # TODO: per node message function
        # TODO: 'sum' should be a separate function
        if func is None:
            self.msg_func = _default_msg_func
        else:
            self.msg_func = func

    def set_update_func(self, func=None, u=None):
        """
        Install the update function applied when a node receives the
        aggregate message from its neighbors.

        Args:
            func: Callable invoked as ``func(u, m)`` with the node and its
                message.  ``None`` installs the default, which adds ``m`` to
                the node's ``'h'`` representation and stores the result.
            u: Currently unused; see the per-node TODO below.
        """
        def _default_update_func(u, m):
            h_old = self.nodes[u]['h']
            # A node that has no representation yet simply adopts the message;
            # ``None + m`` would raise TypeError.
            h_new = m if h_old is None else h_old + m
            self.set_repr(u, h_new)

        # TODO: per node update function
        if func is None:
            self.update_func = _default_update_func
        else:
            self.update_func = func

    def set_readout_func(self, func=None):
        """Install the readout function of the whole graph.

        Args:
            func: Callable taking no arguments.  ``None`` installs the
                default, which sums the ``'h'`` representations of all nodes,
                skipping those that are ``None`` (0 if nothing is summed).
        """
        def _default_readout_func():
            readout = 0
            for n in self.nodes:
                h_n = self.nodes[n]['h']
                # Same convention as the default message function: nodes
                # without a representation contribute nothing.
                if h_n is not None:
                    readout += h_n
            return readout

        if func is None:
            self.readout_func = _default_readout_func
        else:
            self.readout_func = func

    def readout(self):
        """Return the result of the installed readout function."""
        return self.readout_func()

    def update_to(self, u):
        """Pull messages from 1-step away neighbors of u"""
        assert u in self.nodes
        # The message function is always called with ``u`` as a keyword, so
        # custom functions must accept a parameter of that name.
        m = self.msg_func(u=u)
        self.update_func(u, m)

    def update_from(self, u):
        """Update u's 1-step away neighbors"""
        assert u in self.nodes
        # TODO: this asks v to pull from nodes other than
        # TODO: u, is this a good thing?
        for v in self.adj[u]:
            self.update_to(v)

    def print_all(self):
        """Print every node with its attribute dict, then a blank line."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()

if __name__ == '__main__':
    # Dump two freshly wrapped graphs: a 10-node path and a balanced binary
    # tree of height 3.
    for source in (nx.path_graph(10), nx.balanced_tree(2, 3)):
        mx_Graph(source).print_all()

    # Three-node path whose representations are set to 10, 11 and 12.
    chain = mx_Graph(nx.path_graph(3))
    for node in chain:
        chain.set_repr(node, 10 + int(node))
    chain.print_all()
    print(chain.readout())

    # Let node 0 pull from its neighbors and show it before and after.
    print("before update:\t", chain.nodes[0])
    chain.update_to(0)
    print('after update:\t', chain.nodes[0])
    chain.print_all()

    print(chain.readout())