output_pp_block.py 1.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch.nn as nn
import dgl
import dgl.function as fn

from modules.initializers import GlorotOrthogonal

class OutputPPBlock(nn.Module):
    """Output block mapping edge messages to one prediction per graph.

    The edge feature ``'m'`` is multiplied elementwise by a bias-free linear
    projection of the edge feature ``'rbf'``, aggregated onto nodes, passed
    through an up-projection and ``num_dense`` dense layers, projected to
    ``num_targets`` values per node and finally pooled with
    ``dgl.readout_nodes``.

    Parameters
    ----------
    emb_size
        Output width of the ``'rbf'`` projection and input width of the
        up-projection (presumably also the width of ``edata['m']``; confirm
        against the caller).
    out_emb_size
        Width of the up-projected node representation and of every dense
        layer.
    num_radial
        Width of the ``'rbf'`` edge feature.
    num_dense
        Number of ``out_emb_size -> out_emb_size`` dense layers.
    num_targets
        Number of output values produced per node before pooling.
    activation
        Optional callable applied after each dense layer; skipped when
        ``None``.
    output_init
        In-place initialiser applied to the final projection's weight.
    extensive
        If true the node outputs are pooled with ``'sum'``, otherwise with
        ``'mean'``.
    """

    def __init__(
        self,
        emb_size,
        out_emb_size,
        num_radial,
        num_dense,
        num_targets,
        activation=None,
        output_init=nn.init.zeros_,
        extensive=True,
    ):
        super().__init__()

        # Plain configuration.
        self.activation = activation
        self.output_init = output_init
        self.extensive = extensive

        # Sub-modules are created in a fixed order so that parameter
        # registration (and random-number consumption) stays stable.
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)
        hidden_stack = []
        for _ in range(num_dense):
            hidden_stack.append(nn.Linear(out_emb_size, out_emb_size))
        self.dense_layers = nn.ModuleList(hidden_stack)
        self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)

        self.reset_params()

    def reset_params(self):
        """Re-initialise the weights of every linear layer in this block.

        Only ``weight`` tensors are touched; the biases of the dense layers
        keep whatever values they currently hold.
        """
        glorot_targets = [self.dense_rbf, self.up_projection, *self.dense_layers]
        for linear in glorot_targets:
            GlorotOrthogonal(linear.weight)
        # The final projection uses the user-supplied initialiser instead.
        self.output_init(self.dense_final.weight)

    def forward(self, g):
        """Compute the pooled output for graph ``g``.

        ``g`` must carry the edge features ``'m'`` and ``'rbf'``. Features
        written here live inside ``g.local_scope()`` and are not left on the
        caller's graph.

        Returns the result of ``dgl.readout_nodes`` over the per-node
        outputs.
        """
        with g.local_scope():
            messages = g.edata['m']
            rbf_gate = self.dense_rbf(g.edata['rbf'])
            g.edata['tmp'] = messages * rbf_gate

            # Aggregate on the reversed graph: every node sums 'tmp' over its
            # incoming edges there, i.e. over its outgoing edges in ``g``.
            reversed_graph = dgl.reverse(g, copy_edata=True)
            reversed_graph.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
            node_repr = self.up_projection(reversed_graph.ndata['t'])

            for dense in self.dense_layers:
                node_repr = dense(node_repr)
                if self.activation is not None:
                    node_repr = self.activation(node_repr)

            g.ndata['t'] = self.dense_final(node_repr)

            if self.extensive:
                pooling = 'sum'
            else:
                pooling = 'mean'
            return dgl.readout_nodes(g, 't', op=pooling)