# models.py
import copy
from functools import partial

import torch
import torch.nn as nn
from torch.nn import functional as F

import dgl
import dgl.function as fn
import dgl.nn as dglnn


class MLP(nn.Module):
    """
    Fully connected network with ReLU activations between layers.

    Every linear layer (input, hidden and output) is initialised with
    N(0, 0.1) weights and zero biases.

    Parameters
    ==========
    in_feats : int
        Size of the last dimension of the input.

    out_feats : int
        Size of the last dimension of the output.

    num_layers : int
        Total number of linear layers. Values below 2 behave like 2: an
        input layer and an output layer are always built.

    hidden : int
        Width of every hidden representation.
    """

    def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
        super(MLP, self).__init__()
        # Layer widths: in_feats -> hidden -> ... -> hidden -> out_feats,
        # giving exactly max(num_layers, 2) linear layers.
        dims = [in_feats] + [hidden] * (max(num_layers, 2) - 1) + [out_feats]
        # The input layer gets the same initialisation as all other layers.
        self.layers = nn.ModuleList(
            self._init_linear(n_in, n_out)
            for n_in, n_out in zip(dims[:-1], dims[1:])
        )

    @staticmethod
    def _init_linear(n_in, n_out):
        """Return an nn.Linear with N(0, 0.1) weights and zero bias."""
        layer = nn.Linear(n_in, n_out)
        nn.init.normal_(layer.weight, std=0.1)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        """Apply the layers in order, with ReLU after every layer but the last."""
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.layers[-1](x)


class PrepareLayer(nn.Module):
    """
    Normalize node features and derive edge features from them.

    Node features are shifted by ``stat["median"]`` and scaled by
    ``2 / (stat["max"] - stat["min"])``. For every edge ``u -> v`` the edge
    feature is ``x_u - x_v``, the difference of the normalized features of
    its endpoints.

    Parameters
    ==========
    node_feats : int
        Number of node features (stored; not otherwise used by this layer).

    stat : dict
        Normalization statistics with keys ``"median"``, ``"max"`` and
        ``"min"``. Presumably per-feature tensors broadcastable against the
        node features; confirm against the caller.
    """

    def __init__(self, node_feats, stat):
        super().__init__()
        self.node_feats = node_feats
        # Expected layout: {'median': ..., 'max': ..., 'min': ...}
        self.stat = stat

    def normalize_input(self, node_feature):
        """Center on the median and scale by ``2 / (max - min)``."""
        stat = self.stat
        spread = stat["max"] - stat["min"]
        centred = node_feature - stat["median"]
        return centred * (2 / spread)

    def forward(self, g, node_feature):
        """
        Return ``(normalized_node_features, edge_features)``.

        The graph is only modified inside ``g.local_scope()``, so the caller's
        node/edge data are left untouched.
        """
        normalized = self.normalize_input(node_feature)
        with g.local_scope():
            g.ndata["feat"] = normalized  # Only dynamic feature
            # Edge feature = source feature minus destination feature.
            g.apply_edges(fn.u_sub_v("feat", "feat", "e"))
            return normalized, g.edata["e"]


class InteractionNet(nn.Module):
    """
    Simple Interaction Network

    One-layer interaction network for stellar multi-body problem simulation.
    Per the original author it handles systems of no more than 12 bodies;
    this limit is not enforced in this code.

    Parameters
    ==========
    node_feats : int
        Number of node features.

    stat : dict
        Statistics (``"median"``, ``"max"``, ``"min"``) used to denormalize
        the predicted velocities.
    """

    def __init__(self, node_feats, stat):
        super().__init__()
        self.node_feats = node_feats
        self.stat = stat
        # Single interaction layer in "n_n" mode: a 5-layer edge MLP
        # (hidden width 150) and a 2-layer node MLP (hidden width 100).
        self.in_layer = InteractionLayer(
            in_node_feats=node_feats - 3,  # Use velocity only
            in_edge_feats=node_feats,
            out_node_feats=2,
            out_edge_feats=50,
            edge_fn=partial(MLP, num_layers=5, hidden=150),
            node_fn=partial(MLP, num_layers=2, hidden=100),
            mode="n_n",
        )

    def denormalize_output(self, out):
        """Map normalized velocity predictions back to physical scale.

        Uses the statistics of feature columns 3:5 (the velocity columns).
        """
        velocity = slice(3, 5)
        lo = self.stat["min"][velocity]
        hi = self.stat["max"][velocity]
        mid = self.stat["median"][velocity]
        return out * (hi - lo) / 2 + mid

    def forward(self, g, n_feat, e_feat, global_feats, relation_feats):
        """Run the interaction layer; return (denormalized node output, edge output)."""
        with g.local_scope():
            node_out, edge_out = self.in_layer(
                g, n_feat, e_feat, global_feats, relation_feats
            )
            return self.denormalize_output(node_out), edge_out


class InteractionLayer(nn.Module):
    """
    Implementation of a single layer of an interaction network.

    ``forward`` concatenates ``global_feats`` onto the node features and
    ``relation_feats`` onto the edge features. An edge function then produces
    one message per edge. Messages are summed at each destination node, and a
    node function maps ``[node_feat, summed_messages]`` to the new node
    features.

    Parameters
    ==========
    in_node_feats : int
        Number of node features (excluding global features).

    in_edge_feats : int
        Number of edge features (excluding relation features).

    out_node_feats : int
        Number of node features after one interaction.

    out_edge_feats : int
        Number of edge features (message size) after one interaction.

    global_feats : int
        Number of global features concatenated to every node's features.

    relate_feats : int
        Number of relation features concatenated to every edge's features.

    edge_fn : callable
        Factory called as ``edge_fn(in_dim, out_dim)``; the returned module
        computes edge messages.

    node_fn : callable
        Factory called as ``node_fn(in_dim, out_dim)``; the returned module
        computes updated node features.

    mode : str
        What the edge function receives:

        nne : [src_feat, dst_feat, edge_feat], i.e. both endpoint node
              features concatenated with the edge feature. ReLU is applied to
              the edge and node function outputs.
        n_n : edge_feat only. Presumably the caller already supplies node
              differences (src - dst) as edge features. No ReLU is applied.

        Any value other than "nne" behaves like "n_n".
    """

    def __init__(
        self,
        in_node_feats,
        in_edge_feats,
        out_node_feats,
        out_edge_feats,
        global_feats=1,
        relate_feats=1,
        edge_fn=nn.Linear,
        node_fn=nn.Linear,
        mode="nne",
    ):
        super(InteractionLayer, self).__init__()
        self.in_node_feats = in_node_feats
        self.in_edge_feats = in_edge_feats
        self.out_edge_feats = out_edge_feats
        self.out_node_feats = out_node_feats
        self.mode = mode
        # forward() appends global features to node features and relation
        # features to edge features, so both widths must be included here.
        node_in = in_node_feats + global_feats
        edge_in = in_edge_feats + relate_feats
        if mode == "nne":
            # update_edge_fn concatenates [src_feat, dst_feat, edge_feat].
            edge_input = 2 * node_in + edge_in
        else:
            edge_input = edge_in
        # MLP for message passing
        self.edge_fn = edge_fn(edge_input, self.out_edge_feats)  # 50 in IN paper
        # Node function sees [node_feat (+ globals), summed messages].
        self.node_fn = node_fn(node_in + self.out_edge_feats, self.out_node_feats)

    def update_edge_fn(self, edges):
        """Edge UDF ("nne" mode): message from [src_feat, dst_feat, edge_feat]."""
        x = torch.cat(
            [edges.src["feat"], edges.dst["feat"], edges.data["feat"]], dim=1
        )
        ret = F.relu(self.edge_fn(x)) if self.mode == "nne" else self.edge_fn(x)
        return {"e": ret}

    def update_node_fn(self, nodes):
        """Node UDF: new node features from [node_feat, agg].

        ``agg`` is the sum of incoming messages produced by the built-in
        reduce in ``forward``.
        """
        x = torch.cat([nodes.data["feat"], nodes.data["agg"]], dim=1)
        ret = F.relu(self.node_fn(x)) if self.mode == "nne" else self.node_fn(x)
        return {"n": ret}

    def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):
        """
        Run one round of message passing.

        Returns ``(new_node_feats, edge_messages)``. Note: this writes the
        ``feat``, ``e``, ``agg`` and ``n`` fields directly into ``g``'s
        node/edge data (no ``local_scope`` here); wrap the call in
        ``g.local_scope()`` if the graph must stay unchanged.
        """
        g.ndata["feat"] = torch.cat([node_feats, global_feats], dim=1)
        g.edata["feat"] = torch.cat([edge_feats, relation_feats], dim=1)
        if self.mode == "nne":
            g.apply_edges(self.update_edge_fn)
        else:
            g.edata["e"] = self.edge_fn(g.edata["feat"])
        # Sum incoming messages into "agg", then update the nodes.
        g.update_all(
            fn.copy_e("e", "msg"), fn.sum("msg", "agg"), self.update_node_fn
        )
        return g.ndata["n"], g.edata["e"]