import dgl
import torch as th
import numpy as np
import torch.nn as nn
import dgl.function as fn


def _l1_dist(edges):
    # formula 2
    ed = th.norm(edges.src['nd'] - edges.dst['nd'], 1, 1)
    return {'ed': ed}


class CARESampler(dgl.dataloading.BlockSampler):
    def __init__(self, p, dists, num_layers):
        super().__init__(num_layers)
        self.p = p
        self.dists = dists

    def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
        with g.local_scope():
            new_edges_masks = {}
            for etype in g.canonical_etypes:
                edge_mask = th.zeros(g.number_of_edges(etype))
                # extract each node from dict because of single node type
                for node in seed_nodes:
                    edges = g.in_edges(node, form='eid', etype=etype)
                    num_neigh = th.ceil(
                        g.in_degrees(node, etype=etype) * self.p[block_id][etype]
                    ).int().item()
                    neigh_dist = self.dists[block_id][etype][edges]
                    if neigh_dist.shape[0] > num_neigh:
                        neigh_index = np.argpartition(
                            neigh_dist.cpu().detach(), num_neigh
                        )[:num_neigh]
                    else:
                        neigh_index = np.arange(num_neigh)
                    edge_mask[edges[neigh_index]] = 1
                new_edges_masks[etype] = edge_mask.bool()

            return dgl.edge_subgraph(g, new_edges_masks, relabel_nodes=False)

    def __len__(self):
        return self.num_layers
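
# ---------------------------------------------------------------------------
# Illustrative helper, not part of the original example: a sketch of how the
# per-edge distances consumed by CARESampler (self.dists[block_id][etype])
# and RLModule can be produced from `_l1_dist`. It assumes the label-aware
# embeddings have already been stored in g.ndata['nd'] (formula 4); the name
# `compute_edge_dists` is made up for this sketch, and in the full pipeline
# one such dict would be computed per layer.
# ---------------------------------------------------------------------------
def compute_edge_dists(g):
    dists = {}
    with g.local_scope():
        for etype in g.canonical_etypes:
            # _l1_dist writes the distance into the edge feature 'ed'
            g.apply_edges(_l1_dist, etype=etype)
            dists[etype] = g.edges[etype].data['ed']
    return dists
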
class CAREConv(nn.Module):
    """One layer of CARE-GNN."""

    def __init__(self, in_dim, out_dim, num_classes, edges,
                 activation=None, step_size=0.02):
        super(CAREConv, self).__init__()

        self.activation = activation
        self.step_size = step_size
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_classes = num_classes
        self.edges = edges

        self.linear = nn.Linear(self.in_dim, self.out_dim)
        self.MLP = nn.Linear(self.in_dim, self.num_classes)

        # per-relation filtering thresholds and RL bookkeeping
        self.p = {}
        self.last_avg_dist = {}
        self.f = {}
        # indicate whether the RL converges
        self.cvg = {}
        for etype in edges:
            self.p[etype] = 0.5
            self.last_avg_dist[etype] = 0
            self.f[etype] = []
            self.cvg[etype] = False

    def forward(self, g, feat):
        g.srcdata['h'] = feat

        # formula 8: intra-relation aggregation (mean over sampled neighbors)
        hr = {}
        for etype in g.canonical_etypes:
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'hr'), etype=etype)
            hr[etype] = g.dstdata['hr']
            if self.activation is not None:
                hr[etype] = self.activation(hr[etype])

        # formula 9: inter-relation aggregation, using the mean weighted by p
        p_tensor = th.Tensor(list(self.p.values())).view(-1, 1, 1).to(feat.device)
        h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
        h_homo += feat[:g.number_of_dst_nodes()]
        if self.activation is not None:
            h_homo = self.activation(h_homo)

        return self.linear(h_homo)


class CAREGNN(nn.Module):
    def __init__(self, in_dim, num_classes, hid_dim=64, edges=None,
                 num_layers=2, activation=None, step_size=0.02):
        super(CAREGNN, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_classes = num_classes
        self.edges = edges
        self.num_layers = num_layers
        self.activation = activation
        self.step_size = step_size

        self.layers = nn.ModuleList()

        if self.num_layers == 1:
            # Single layer
            self.layers.append(CAREConv(self.in_dim, self.num_classes,
                                        self.num_classes, self.edges,
                                        activation=self.activation,
                                        step_size=self.step_size))
        else:
            # Input layer
            self.layers.append(CAREConv(self.in_dim, self.hid_dim,
                                        self.num_classes, self.edges,
                                        activation=self.activation,
                                        step_size=self.step_size))

            # Hidden layers (n - 2 in total)
            for i in range(self.num_layers - 2):
                self.layers.append(CAREConv(self.hid_dim, self.hid_dim,
                                            self.num_classes, self.edges,
                                            activation=self.activation,
                                            step_size=self.step_size))

            # Output layer
            self.layers.append(CAREConv(self.hid_dim, self.num_classes,
                                        self.num_classes, self.edges,
                                        activation=self.activation,
                                        step_size=self.step_size))

    def forward(self, blocks, feat):
        # formula 4: label-aware similarity from the first layer's MLP
        sim = th.tanh(self.layers[0].MLP(blocks[-1].dstdata['feature'].float()))

        # forward through the n CARE-GNN layers
        for block, layer in zip(blocks, self.layers):
            feat = layer(block, feat)

        return feat, sim

    def RLModule(self, graph, epoch, idx, dists):
        for i, layer in enumerate(self.layers):
            for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5: average neighbor distance around the batch nodes
                    eid = graph.in_edges(idx, form='eid', etype=etype)
                    avg_dist = th.mean(dists[i][etype][eid])

                    # formula 6: adjust the filtering threshold p
                    if layer.last_avg_dist[etype] < avg_dist:
                        layer.p[etype] -= self.step_size
                        layer.f[etype].append(-1)
                        # keep p within (0, 1), following the authors' implementation
                        if layer.p[etype] < 0:
                            layer.p[etype] = 0.001
                    else:
                        layer.p[etype] += self.step_size
                        layer.f[etype].append(+1)
                        if layer.p[etype] > 1:
                            layer.p[etype] = 0.999
                    layer.last_avg_dist[etype] = avg_dist

                    # formula 7: stop updating once the recent updates oscillate
                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                        layer.cvg[etype] = True
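

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original example):
# builds a tiny single-node-type heterograph with two made-up relations and
# runs one full-graph forward pass by passing the graph itself as every
# "block". The relation names, sizes, and random features are hypothetical;
# the real example trains on fraud-detection graphs with CARESampler-based
# mini-batching and calls RLModule after each epoch.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_nodes, in_dim, num_classes = 10, 8, 2
    edges = ['rel1', 'rel2']  # hypothetical relation names
    g = dgl.heterograph(
        {('user', e, 'user'): (th.randint(0, num_nodes, (40,)),
                               th.randint(0, num_nodes, (40,)))
         for e in edges},
        num_nodes_dict={'user': num_nodes})
    g.ndata['feature'] = th.randn(num_nodes, in_dim)

    model = CAREGNN(in_dim, num_classes, hid_dim=16, edges=edges,
                    num_layers=2, activation=th.tanh)
    logits, sim = model([g, g], g.ndata['feature'])
    print(logits.shape, sim.shape)  # torch.Size([10, 2]) torch.Size([10, 2])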