# This file is based on the CompGCN author's implementation
# <https://github.com/malllabiisc/CompGCN>.
# It implements the circular-correlation operation in the ccorr function and an
# additional in_out_norm function for norm computation.

import torch as th

import dgl


def com_mult(a, b):
    """
    Multiply two tensors of complex numbers stored as (real, imag) pairs.

    Parameters
    ----------
    a: Tensor whose last dimension holds the real part at index 0 and the
        imaginary part at index 1.
    b: Tensor laid out like ``a``; the remaining dimensions broadcast
        against those of ``a``.

    Returns
    -------
    Tensor whose last dimension has size 2, holding the real and imaginary
    parts of the element-wise complex product ``a * b``.
    """
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)


def conj(a):
    """
    Return the complex conjugate of a tensor stored as (real, imag) pairs.

    Parameters
    ----------
    a: Tensor whose last dimension holds the real part at index 0 and the
        imaginary part at index 1.

    Returns
    -------
    A new tensor equal to ``a`` except that the entries at index 1 of the
    last dimension are negated. ``a`` itself is not modified.
    """
    # Negate a copy: writing into ``a`` would silently change the caller's
    # tensor and raises under autograd when ``a`` is a leaf requiring grad.
    out = a.clone()
    out[..., 1] = -out[..., 1]
    return out


def ccorr(a, b):
    """
    Compute circular correlation of two tensors along their last dimension.

    Parameters
    ----------
    a: Tensor, 1D or 2D (any rank works; only the last dimension is
        transformed)
    b: Tensor, 1D or 2D, with the same last-dimension size as ``a``

    Notes
    -----
    Input a and b should have the same last dimension; the leading
    dimensions support broadcasting. With ``d = a.shape[-1]`` the result is
    ``out[..., k] = sum_i a[..., i] * b[..., (i + k) % d]``, computed via
    real FFTs as ``irfft(conj(rfft(a)) * rfft(b))``.

    Returns
    -------
    Tensor whose last dimension has the same size as that of ``a``.
    """
    d = a.shape[-1]
    spectrum = th.conj(th.fft.rfft(a, dim=-1)) * th.fft.rfft(b, dim=-1)
    # rfft keeps only d // 2 + 1 frequency bins. Without an explicit ``n``,
    # irfft rebuilds 2 * (d // 2) samples, which is one short when d is odd.
    return th.fft.irfft(spectrum, n=d, dim=-1)


def in_out_norm(graph):
    """
    Identify in/out edges, compute an edge norm for each and store it in edata.

    In-edges and out-edges are selected by the non-zero entries of
    ``graph.edata["in_edges_mask"]`` and ``graph.edata["out_edges_mask"]``
    and are normalised independently. Within one subset, ``deg(n)`` counts
    the subset's edges whose destination is ``n``, and each edge (u, v) gets
    ``deg(u) ** -0.5 * deg(v) ** -0.5``, with a zero degree contributing 0
    instead of infinity. An edge whose source has no incoming edge in the
    subset therefore gets norm 0.

    Edges selected by neither mask keep a norm of 1. An edge selected by
    both masks ends up with its out-edge norm, because the out-edge pass
    runs second.

    Parameters
    ----------
    graph: DGL graph whose edata contains ``in_edges_mask`` and
        ``out_edges_mask`` (presumably 1-D with one entry per edge; confirm
        against the caller).

    Returns
    -------
    The same graph, with ``graph.edata["norm"]`` set to a float tensor of
    shape (num_edges, 1) on ``graph.device``.
    """
    src, dst = graph.edges()
    norm = th.ones(src.shape[0], device=graph.device)

    for mask_key in ("in_edges_mask", "out_edges_mask"):
        # as_tuple=True always yields a 1-D index tensor, even for 0 or 1 hits.
        idx = th.nonzero(graph.edata[mask_key], as_tuple=True)[0]
        u, v = src[idx], dst[idx]

        # Count, per node, the edges of this subset that point at it.
        deg = th.zeros(graph.num_nodes(), device=graph.device)
        nodes, counts = th.unique(v, return_counts=True)
        deg[nodes] = counts.float()

        deg_inv = deg.pow(-0.5)  # D^{-0.5}
        deg_inv[th.isinf(deg_inv)] = 0  # zero-degree nodes contribute 0
        norm[idx] = deg_inv[u] * deg_inv[v]

    graph.edata["norm"] = norm.unsqueeze(1)
    return graph