import torch
import torch.nn as nn
import dgl.function as fn

from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal


class InteractionBlock(nn.Module):
    def __init__(self,
                 emb_size,
                 num_radial,
                 num_spherical,
                 num_bilinear,
                 num_before_skip,
                 num_after_skip,
                 activation=None):
        super(InteractionBlock, self).__init__()

        self.activation = activation
        # Transformations of Bessel and spherical basis representations
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.dense_sbf = nn.Linear(num_radial * num_spherical, num_bilinear, bias=False)
        # Dense transformations of input messages
        self.dense_ji = nn.Linear(emb_size, emb_size)
        self.dense_kj = nn.Linear(emb_size, emb_size)
        # Bilinear layer
        bilin_initializer = torch.empty((emb_size, num_bilinear, emb_size)).normal_(mean=0, std=2 / emb_size)
        self.W_bilin = nn.Parameter(bilin_initializer)
        # Residual layers before skip connection
        self.layers_before_skip = nn.ModuleList([
            ResidualLayer(emb_size, activation=activation) for _ in range(num_before_skip)
        ])
        self.final_before_skip = nn.Linear(emb_size, emb_size)
        # Residual layers after skip connection
        self.layers_after_skip = nn.ModuleList([
            ResidualLayer(emb_size, activation=activation) for _ in range(num_after_skip)
        ])

        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_rbf.weight)
        GlorotOrthogonal(self.dense_sbf.weight)
        GlorotOrthogonal(self.dense_ji.weight)
        GlorotOrthogonal(self.dense_kj.weight)
        GlorotOrthogonal(self.final_before_skip.weight)

    def edge_transfer(self, edges):
        # Transform from Bessel basis to dense vector
        rbf = self.dense_rbf(edges.data['rbf'])
        # Initial transformations of the incoming messages
        x_ji = self.dense_ji(edges.data['m'])
        x_kj = self.dense_kj(edges.data['m'])
        if self.activation is not None:
            x_ji = self.activation(x_ji)
            x_kj = self.activation(x_kj)

        # x_kj: (W * e_RBF) \odot \sigma(W * m + b)
        return {'x_kj': x_kj * rbf, 'x_ji': x_ji}

    def msg_func(self, edges):
        sbf = self.dense_sbf(edges.data['sbf'])
        # Apply bilinear layer to interactions and basis function activation
        # [None, 8] (sbf) x [None, 128] (x_kj) x [128, 8, 128] (W_bilin) -> [None, 128]
        x_kj = torch.einsum("wj,wl,ijl->wi", sbf, edges.src['x_kj'], self.W_bilin)
        return {'x_kj': x_kj}

    def forward(self, g, l_g):
        g.apply_edges(self.edge_transfer)

        # In the line graph, nodes correspond to edges and edges correspond to
        # node-sharing edge pairs (triplets) of the original graph.
        # Copied node features: d, rbf, o, rbf_env, x_kj, x_ji
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        l_g.update_all(self.msg_func, fn.sum('x_kj', 'm_update'))

        for k, v in l_g.ndata.items():
            g.edata[k] = v

        # Transformations before skip connection
        g.edata['m_update'] = g.edata['m_update'] + g.edata['x_ji']
        for layer in self.layers_before_skip:
            g.edata['m_update'] = layer(g.edata['m_update'])
        g.edata['m_update'] = self.final_before_skip(g.edata['m_update'])
        if self.activation is not None:
            g.edata['m_update'] = self.activation(g.edata['m_update'])

        # Skip connection
        g.edata['m'] = g.edata['m'] + g.edata['m_update']

        # Transformations after skip connection
        for layer in self.layers_after_skip:
            g.edata['m'] = layer(g.edata['m'])

        return g
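

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the module. Assumptions: a
    # DGL-style API with dgl.graph / dgl.line_graph; the edge-feature names
    # 'm' and 'rbf' on the molecular graph and 'sbf' on its line graph, as
    # expected by forward() above; all sizes and random features are
    # illustrative, not values from the original model.
    import dgl
    import torch.nn.functional as F

    emb_size, num_radial, num_spherical, num_bilinear = 128, 6, 7, 8

    # Toy directed graph standing in for the interatomic distance graph.
    g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    g.edata['m'] = torch.randn(g.num_edges(), emb_size)
    g.edata['rbf'] = torch.randn(g.num_edges(), num_radial)

    # Line graph: its nodes are the edges of g, its edges the (k -> j -> i)
    # triplets; backtracking=False drops degenerate k == i triplets.
    l_g = dgl.line_graph(g, backtracking=False)
    l_g.edata['sbf'] = torch.randn(l_g.num_edges(), num_radial * num_spherical)

    block = InteractionBlock(emb_size, num_radial, num_spherical, num_bilinear,
                             num_before_skip=1, num_after_skip=2,
                             activation=F.silu)
    g = block(g, l_g)
    print(g.edata['m'].shape)  # expected: (num_edges, emb_size)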