import torch import torch.nn as nn import torch.nn.functional as F class ElementWiseProductPredictor(nn.Module): def __init__(self, data_info: dict, hidden_size: int = 64, num_layers: int = 2, bias: bool = True): """Elementwise product model for edge scores Parameters ---------- data_info : dict The information about the input dataset. hidden_size : int Hidden size. num_layers : int Number of hidden layers. bias : bool Whether to use bias in the linaer layer. """ super(ElementWiseProductPredictor, self).__init__() lins_list = [] in_size, out_size = data_info["in_size"], data_info["out_size"] for i in range(num_layers): in_hiddnen = in_size if i == 0 else hidden_size out_hidden = hidden_size if i < num_layers - 1 else out_size lins_list.append(nn.Linear(in_hiddnen, out_hidden, bias=bias)) if i < num_layers - 1: lins_list.append(nn.ReLU()) self.linear = nn.Sequential(*lins_list) def forward(self, h_src, h_dst): h = h_src * h_dst h = self.linear(h) h = torch.sigmoid(h) return h