""" Recurrent Relational Network(RRN) module References: - Recurrent Relational Networks - Paper: https://arxiv.org/abs/1711.08028 - Original Code: https://github.com/rasmusbergpalm/recurrent-relational-networks """ import torch from torch import nn import dgl.function as fn class RRNLayer(nn.Module): def __init__(self, msg_layer, node_update_func, edge_drop): super(RRNLayer, self).__init__() self.msg_layer = msg_layer self.node_update_func = node_update_func self.edge_dropout = nn.Dropout(edge_drop) def forward(self, g): g.apply_edges(self.get_msg) g.edata['e'] = self.edge_dropout(g.edata['e']) g.update_all(message_func=fn.copy_e('e', 'msg'), reduce_func=fn.sum('msg', 'm')) g.apply_nodes(self.node_update) def get_msg(self, edges): e = torch.cat([edges.src['h'], edges.dst['h']], -1) e = self.msg_layer(e) return {'e': e} def node_update(self, nodes): return self.node_update_func(nodes) class RRN(nn.Module): def __init__(self, msg_layer, node_update_func, num_steps, edge_drop): super(RRN, self).__init__() self.num_steps = num_steps self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop) def forward(self, g, get_all_outputs=True): outputs = [] for _ in range(self.num_steps): self.rrn_layer(g) if get_all_outputs: outputs.append(g.ndata['h']) if get_all_outputs: outputs = torch.stack(outputs, 0) # num_steps x n_nodes x h_dim else: outputs = g.ndata['h'] # n_nodes x h_dim return outputs