import torch.nn as nn
from dgl.nn import GINConv
from dgl.base import dgl_warning


class GIN(nn.Module):
    """Node encoder made of stacked ``GINConv`` layers followed by a linear head."""

    def __init__(self,
                 data_info: dict,
                 embed_size: int = -1,
                 hidden_size=64,
                 num_layers=3,
                 aggregator_type='sum'):
        """Graph Isomorphism Networks

        Parameters
        ----------
        data_info : dict
            The information about the input dataset. Keys read here:
            ``"num_nodes"`` (only when ``embed_size > 0``), ``"in_size"``
            (only when ``embed_size <= 0``) and ``"out_size"`` (always).
        embed_size : int
            The dimension of created embedding table. -1 means using original node embedding
        hidden_size : int
            Hidden size.
        num_layers : int
            Number of layers.
        aggregator_type : str
            Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
        """
        super().__init__()
        self.data_info = data_info
        self.embed_size = embed_size
        self.num_layers = num_layers
        # The output dimension was previously read from ``self.out_size``
        # before it was ever assigned, so construction always raised
        # AttributeError. It is taken from ``data_info``, like the input
        # size and node count below.
        self.out_size = data_info["out_size"]
        self.conv_list = nn.ModuleList()

        if embed_size > 0:
            # Learnable per-node embedding table; replaces input features.
            self.embed = nn.Embedding(data_info["num_nodes"], embed_size)
            in_size = embed_size
        else:
            in_size = data_info["in_size"]

        for i in range(num_layers):
            # Only the first layer consumes the raw input dimension.
            input_dim = in_size if i == 0 else hidden_size
            mlp = nn.Sequential(nn.Linear(input_dim, hidden_size),
                                nn.BatchNorm1d(hidden_size),
                                nn.ReLU(),
                                nn.Linear(hidden_size, hidden_size),
                                nn.ReLU())
            # Positional args: apply_func, aggregator_type, init_eps=1e-5,
            # learn_eps=True.
            self.conv_list.append(GINConv(mlp, aggregator_type, 1e-5, True))
        self.out_mlp = nn.Linear(hidden_size, self.out_size)

    def forward(self, graph, node_feat, edge_feat=None):
        """Encode the nodes of ``graph``.

        Parameters
        ----------
        graph : DGLGraph
            Input graph.
        node_feat : Tensor
            Input node features. Ignored (with a one-time warning) when the
            model was built with ``embed_size > 0``.
        edge_feat : Tensor, optional
            Accepted for interface compatibility; not used.

        Returns
        -------
        Tensor
            Output of the final linear layer, with ``out_size`` columns.
        """
        if self.embed_size > 0:
            dgl_warning(
                "The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.",
                norepeat=True)
            # NOTE(review): assumes the embedding table size (num_nodes)
            # matches graph.num_nodes() — confirm for sampled/batched graphs.
            h = self.embed.weight
        else:
            h = node_feat
        for conv in self.conv_list:
            h = conv(graph, h)
        return self.out_mlp(h)