utils.py 1.86 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
2
3
4
5
# This file is based on the CompGCN authors' implementation
# <https://github.com/malllabiisc/CompGCN/blob/master/helper.py>.
# It implements circular correlation (computed via FFT) in the ccorr function, plus an in_out_norm function that computes per-edge normalization constants.

import torch as th
6

KounianhuaDu's avatar
KounianhuaDu committed
7
8
import dgl

9

KounianhuaDu's avatar
KounianhuaDu committed
10
def com_mult(a, b):
    """
    Multiply two complex tensors stored as (real, imaginary) pairs.

    Parameters
    ----------
    a: Tensor whose last dimension holds the real part at index 0 and the
        imaginary part at index 1 (only these two indices are read)
    b: Tensor in the same layout as a; combined with a element-wise

    Returns
    -------
    Tensor whose last dimension has size 2, holding the real and imaginary
    parts of the element-wise complex product a * b.
    """
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    # (a_re + i*a_im)(b_re + i*b_im) = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
    real = a_re * b_re - a_im * b_im
    imag = a_re * b_im + a_im * b_re
    return th.stack((real, imag), dim=-1)
KounianhuaDu's avatar
KounianhuaDu committed
14
15
16


def conj(a):
    """
    Complex-conjugate a tensor stored as (real, imaginary) pairs, in place.

    The imaginary component (index 1 of the last dimension) is negated
    directly in ``a``'s storage, so the caller's tensor is modified.

    Parameters
    ----------
    a: Tensor whose last dimension holds the real part at index 0 and the
        imaginary part at index 1

    Returns
    -------
    The same tensor object ``a``, now conjugated.
    """
    # basic indexing returns a view, so the in-place negation writes into ``a``
    a[..., 1].neg_()
    return a
KounianhuaDu's avatar
KounianhuaDu committed
19
20
21


def ccorr(a, b):
    """
    Compute the circular correlation of two tensors along their last dimension.

    For a last-dimension length ``d`` the result is
    ``out[..., k] = sum_i a[..., i] * b[..., (i + k) % d]``, evaluated in
    O(d log d) via the identity ``irfft(conj(rfft(a)) * rfft(b))``.

    Parameters
    ----------
    a: Tensor, 1D or higher; the correlation runs over the last dimension
    b: Tensor with the same last-dimension size as a

    Notes
    -----
    Input a and b should have the same size in the last dimension. Leading
    dimensions follow the usual broadcasting rules.

    Returns
    -------
    Tensor whose last dimension has the same size as that of input a.
    """
    d = a.shape[-1]
    # The signal length must be passed explicitly: without ``n``, irfft
    # assumes an even output length 2 * (bins - 1), which for odd d returns
    # d - 1 incorrect values instead of d.
    return th.fft.irfft(
        th.conj(th.fft.rfft(a, dim=-1)) * th.fft.rfft(b, dim=-1),
        n=d,
        dim=-1,
    )


# identify in/out edges, compute edge norm for each and store in edata
KounianhuaDu's avatar
KounianhuaDu committed
43
def in_out_norm(graph):
    """
    Compute per-edge symmetric degree normalization for in- and out-edges.

    The edges selected by the nonzero entries of ``graph.edata["in_edges_mask"]``
    and of ``graph.edata["out_edges_mask"]`` are normalized as two separate
    groups. Within a group, ``deg[n]`` counts the group's edges whose
    destination is node ``n``. An edge ``u -> v`` in that group gets
    ``deg[u] ** -0.5 * deg[v] ** -0.5``; a zero degree contributes 0 instead
    of inf. Edges selected by neither mask keep a norm of 1.

    Parameters
    ----------
    graph: DGL graph whose ``edata`` contains "in_edges_mask" and
        "out_edges_mask"

    Returns
    -------
    The same graph object, with ``graph.edata["norm"]`` set to a float tensor
    of shape (num_edges, 1).
    """
    src, dst = graph.edges()
    num_nodes = graph.num_nodes()
    norm = th.ones(graph.num_edges(), device=graph.device)

    for mask_key in ("in_edges_mask", "out_edges_mask"):
        # as_tuple=True keeps idx 1-D even when exactly one edge is selected
        idx = th.nonzero(graph.edata[mask_key], as_tuple=True)[0]
        u, v = src[idx], dst[idx]
        # number of this group's edges ending at each node
        deg = th.bincount(v, minlength=num_nodes).float()
        deg_inv = deg.pow(-0.5)  # D^{-0.5}
        deg_inv[deg_inv == float("inf")] = 0
        norm[idx] = deg_inv[u] * deg_inv[v]

    # Build the column locally and store it once instead of writing in place
    # into whatever tensor graph.edata hands back.
    graph.edata["norm"] = norm.unsqueeze(1)
    return graph