import copy

import dgl
import torch
from torch import nn
from torch.nn import BatchNorm1d, Parameter
from torch.nn.init import ones_, zeros_

from dgl.nn.pytorch.conv import GraphConv, SAGEConv


class LayerNorm(nn.Module):
    r"""Graph-aware layer normalization.

    Normalizes node features using a single scalar mean and variance per graph,
    computed over all nodes and feature channels. When ``batch`` is ``None``,
    all nodes are treated as belonging to one graph.
    """
    def __init__(self, in_channels, eps=1e-5, affine=True):
        super().__init__()
        self.in_channels = in_channels
        self.eps = eps
        if affine:
            self.weight = Parameter(torch.Tensor(in_channels))
            self.bias = Parameter(torch.Tensor(in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Only initialize the affine parameters when they exist; with
        # affine=False, weight and bias are registered as None.
        if self.weight is not None:
            ones_(self.weight)
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, batch=None):
        device = x.device
        if batch is None:
            # Normalize over all nodes and features at once.
            x = x - x.mean()
            out = x / (x.std(unbiased=False) + self.eps)
        else:
            # Normalize each graph in the batch independently.
            batch_size = int(batch.max()) + 1
            batch_idx = [batch == i for i in range(batch_size)]
            # Entries per graph: node count times feature dimension.
            norm = torch.tensor([i.sum() for i in batch_idx], dtype=x.dtype).clamp_(min=1).to(device)
            norm = norm.mul_(x.size(-1)).view(-1, 1)
            tmp_list = [x[i] for i in batch_idx]
            # Per-graph scalar mean over all nodes and features.
            mean = torch.cat([t.sum(0).unsqueeze(0) for t in tmp_list], dim=0).sum(dim=-1, keepdim=True)
            mean = mean / norm
            x = x - mean.index_select(0, batch.long())
            # The variance must be computed on the *centered* features, so
            # re-slice x after subtracting the mean.
            tmp_list = [x[i] for i in batch_idx]
            var = torch.cat([(t * t).sum(0).unsqueeze(0) for t in tmp_list], dim=0).sum(dim=-1, keepdim=True)
            var = var / norm
            out = x / (var + self.eps).sqrt().index_select(0, batch.long())
        if self.weight is not None and self.bias is not None:
            out = out * self.weight + self.bias
        return out

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'
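
# A minimal sketch of how the LayerNorm above handles batched graphs: nodes
# 0-2 belong to graph 0 and nodes 3-4 to graph 1, so each graph is normalized
# independently. The shapes and the batch vector are illustrative assumptions,
# not part of the original module.
def _layer_norm_example():
    ln = LayerNorm(in_channels=8)
    x = torch.randn(5, 8)                  # 5 nodes, 8 features each
    batch = torch.tensor([0, 0, 0, 1, 1])  # graph id of each node
    out = ln(x, batch)                     # per-graph normalization
    assert out.shape == (5, 8)
    return out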
""" def __init__(self, input_size, output_size, hidden_size=512): super().__init__() self.net = nn.Sequential( nn.Linear(input_size, hidden_size, bias=True), nn.PReLU(1), nn.Linear(hidden_size, output_size, bias=True) ) self.reset_parameters() def forward(self, x): return self.net(x) def reset_parameters(self): # kaiming_uniform for m in self.modules(): if isinstance(m, nn.Linear): m.reset_parameters() class GCN(nn.Module): def __init__(self, layer_sizes, batch_norm_mm=0.99): super(GCN, self).__init__() self.layers = nn.ModuleList() for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]): self.layers.append(GraphConv(in_dim, out_dim)) self.layers.append(BatchNorm1d(out_dim, momentum=batch_norm_mm)) self.layers.append(nn.PReLU()) def forward(self, g): x = g.ndata['feat'] for layer in self.layers: if isinstance(layer, GraphConv): x = layer(g, x) else: x = layer(x) return x def reset_parameters(self): for layer in self.layers: if hasattr(layer, 'reset_parameters'): layer.reset_parameters() class GraphSAGE_GCN(nn.Module): def __init__(self, layer_sizes): super().__init__() input_size, hidden_size, embedding_size = layer_sizes self.convs = nn.ModuleList([ SAGEConv(input_size, hidden_size, 'mean'), SAGEConv(hidden_size, hidden_size, 'mean'), SAGEConv(hidden_size, embedding_size, 'mean') ]) self.skip_lins = nn.ModuleList([ nn.Linear(input_size, hidden_size, bias=False), nn.Linear(input_size, hidden_size, bias=False), ]) self.layer_norms = nn.ModuleList([ LayerNorm(hidden_size), LayerNorm(hidden_size), LayerNorm(embedding_size), ]) self.activations = nn.ModuleList([ nn.PReLU(), nn.PReLU(), nn.PReLU(), ]) def forward(self, g): x = g.ndata['feat'] if 'batch' in g.ndata.keys(): batch = g.ndata['batch'] else: batch = None h1 = self.convs[0](g, x) h1 = self.layer_norms[0](h1, batch) h1 = self.activations[0](h1) x_skip_1 = self.skip_lins[0](x) h2 = self.convs[1](g, h1 + x_skip_1) h2 = self.layer_norms[1](h2, batch) h2 = self.activations[1](h2) x_skip_2 = self.skip_lins[1](x) ret = self.convs[2](g, h1 + h2 + x_skip_2) ret = self.layer_norms[2](ret, batch) ret = self.activations[2](ret) return ret def reset_parameters(self): for m in self.convs: m.reset_parameters() for m in self.skip_lins: m.reset_parameters() for m in self.activations: m.weight.data.fill_(0.25) for m in self.layer_norms: m.reset_parameters() class BGRL(nn.Module): r"""BGRL architecture for Graph representation learning. Args: encoder (torch.nn.Module): Encoder network to be duplicated and used in both online and target networks. predictor (torch.nn.Module): Predictor network used to predict the target projection from the online projection. .. note:: `encoder` must have a `reset_parameters` method, as the weights of the target network will be initialized differently from the online network. """ def __init__(self, encoder, predictor): super(BGRL, self).__init__() # online network self.online_encoder = encoder self.predictor = predictor # target network self.target_encoder = copy.deepcopy(encoder) # reinitialize weights self.target_encoder.reset_parameters() # stop gradient for param in self.target_encoder.parameters(): param.requires_grad = False def trainable_parameters(self): r"""Returns the parameters that will be updated via an optimizer.""" return list(self.online_encoder.parameters()) + list(self.predictor.parameters()) @torch.no_grad() def update_target_network(self, mm): r"""Performs a momentum update of the target network's weights. Args: mm (float): Momentum used in moving average update. 
""" for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm) def forward(self, online_x, target_x): # forward online network online_y = self.online_encoder(online_x) # prediction online_q = self.predictor(online_y) # forward target network with torch.no_grad(): target_y = self.target_encoder(target_x).detach() return online_q, target_y def compute_representations(net, dataset, device): r"""Pre-computes the representations for the entire data. Returns: [torch.Tensor, torch.Tensor]: Representations and labels. """ net.eval() reps = [] labels = [] if len(dataset) == 1: g = dataset[0] g = dgl.add_self_loop(g) g = g.to(device) with torch.no_grad(): reps.append(net(g)) labels.append(g.ndata['label']) else: for g in dataset: # forward g = g.to(device) with torch.no_grad(): reps.append(net(g)) labels.append(g.ndata['label']) reps = torch.cat(reps, dim=0) labels = torch.cat(labels, dim=0) return [reps, labels]