import torch.nn as nn
import dgl
import dgl.function as fn

from modules.initializers import GlorotOrthogonal


class OutputBlock(nn.Module):
    """Map edge messages to per-graph target predictions.

    Edge messages ``g.edata['m']`` are multiplied elementwise by a linear
    projection of ``g.edata['rbf']``, summed onto their destination nodes,
    passed through ``num_dense`` hidden dense layers and a final bias-free
    projection, and finally summed over the nodes of each graph.

    Args:
        emb_size: Width of the projected rbf features and of the hidden
            dense layers.
        num_radial: Size of the last dimension of ``g.edata['rbf']``
            (the input width of ``dense_rbf``).
        num_dense: Number of hidden ``emb_size -> emb_size`` dense layers.
        num_targets: Number of outputs produced per graph.
        activation: Optional callable applied after every hidden dense
            layer (not after the final projection). ``None`` means no
            activation.
        output_init: Initializer applied to the final projection weight;
            defaults to ``nn.init.zeros_``.
    """

    def __init__(self, emb_size, num_radial, num_dense, num_targets,
                 activation=None, output_init=nn.init.zeros_):
        super().__init__()
        self.activation = activation
        self.output_init = output_init
        # Projects rbf features to emb_size so they can gate the messages 'm'.
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.dense_layers = nn.ModuleList(
            [nn.Linear(emb_size, emb_size) for _ in range(num_dense)]
        )
        self.dense_final = nn.Linear(emb_size, num_targets, bias=False)
        self.reset_params()

    def reset_params(self):
        """(Re-)initialize every learnable parameter of the block.

        Weights of ``dense_rbf`` and the hidden layers go through
        ``GlorotOrthogonal`` (presumably an in-place initializer -- see
        ``modules.initializers``); the final projection uses
        ``output_init``. Hidden-layer biases are explicitly zeroed so that
        calling this method again really resets them instead of leaving
        them at previously trained values.
        """
        GlorotOrthogonal(self.dense_rbf.weight)
        for layer in self.dense_layers:
            GlorotOrthogonal(layer.weight)
            nn.init.zeros_(layer.bias)
        self.output_init(self.dense_final.weight)

    def forward(self, g):
        """Compute per-graph outputs from the edge data of ``g``.

        Args:
            g: A (possibly batched) DGL graph whose ``edata`` holds ``'m'``
               and ``'rbf'``. ``'rbf'`` must have ``num_radial`` features in
               its last dimension; ``'m'`` must broadcast against the
               ``emb_size``-wide projection of ``'rbf'``.

        Returns:
            The node features after ``dense_final``, summed per graph via
            ``dgl.readout_nodes`` (whose default reduction is ``'sum'``).
            The graph's own features are left untouched (``local_scope``).
        """
        with g.local_scope():
            g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
            # Sum gated edge messages onto their destination nodes.
            g.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
            t = g.ndata['t']
            for layer in self.dense_layers:
                t = layer(t)
                if self.activation is not None:
                    t = self.activation(t)
            g.ndata['t'] = self.dense_final(t)
            return dgl.readout_nodes(g, 't')