import torch as th
import torch.nn as nn
from utils import local_global_loss_

from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.glob import SumPooling


class MLP(nn.Module):
    """Three-layer perceptron with PReLU activations and a linear
    shortcut, used to project node- and graph-level representations
    before the contrastive loss."""

    def __init__(self, in_dim, out_dim):
        super(MLP, self).__init__()
        self.fcs = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.PReLU(),
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
        )
        self.linear_shortcut = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fcs(x) + self.linear_shortcut(x)
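
# Illustrative usage (hypothetical shapes, not executed at import):
#   proj = MLP(16, 32)
#   proj(th.rand(5, 16)).shape  # -> torch.Size([5, 32])
# The linear shortcut keeps the block residual even when in_dim != out_dim.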


class GCN(nn.Module):
    """GCN encoder: a stack of GraphConv layers whose per-layer
    sum-pooled readouts are concatenated into the graph-level
    representation."""

    def __init__(self, in_dim, out_dim, num_layers, norm):
        super(GCN, self).__init__()

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        self.layers.append(
            GraphConv(
                in_dim, out_dim, bias=False, norm=norm, activation=nn.PReLU()
            )
        )
        self.pooling = SumPooling()

        for _ in range(num_layers - 1):
            self.layers.append(
                GraphConv(
                    out_dim,
                    out_dim,
                    bias=False,
                    norm=norm,
                    activation=nn.PReLU(),
                )
            )

    def forward(self, graph, feat, edge_weight=None):
        h = self.layers[0](graph, feat, edge_weight=edge_weight)
        hg = self.pooling(graph, h)

        # Concatenate each layer's pooled readout along the feature
        # dimension, so hg carries num_layers * out_dim features per graph.
        for idx in range(self.num_layers - 1):
            h = self.layers[idx + 1](graph, h, edge_weight=edge_weight)
            hg = th.cat((hg, self.pooling(graph, h)), -1)

        return h, hg
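

# Dimension sketch (illustrative helper, not part of the training
# pipeline; assumes dgl is installed): each of the num_layers sum-pooled
# readouts is concatenated, so the graph embedding has
# num_layers * out_dim features -- which is why MVGRL sizes its
# global_mlp input accordingly below.
def _gcn_shape_demo():
    import dgl

    # A 3-node toy cycle; self-loops avoid zero-in-degree issues.
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
    enc = GCN(in_dim=8, out_dim=16, num_layers=2, norm="both")
    h, hg = enc(g, th.rand(3, 8))
    assert h.shape == (3, 16)  # node-level: (num_nodes, out_dim)
    assert hg.shape == (1, 32)  # graph-level: (num_graphs, num_layers * out_dim)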


class MVGRL(nn.Module):
    r"""MVGRL model: contrasts node-level (local) representations of one
    view against graph-level (global) representations of the other view.

    Parameters
    ----------
    in_dim : int
        Input feature size.
    out_dim : int
        Output feature size.
    num_layers : int
        Number of GNN encoder layers.

    Functions
    ---------
    forward(graph1, graph2, feat, edge_weight, graph_id):
        graph1 : DGLGraph
            The original graph.
        graph2 : DGLGraph
            The diffusion graph.
        feat : Tensor
            Node features.
        edge_weight : Tensor
            Edge weights of the diffusion graph.
        graph_id : Tensor
            The graph each node belongs to, used to pair local and
            global representations in the contrastive loss.
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super(MVGRL, self).__init__()
        self.local_mlp = MLP(out_dim, out_dim)
        self.global_mlp = MLP(num_layers * out_dim, out_dim)

        # Symmetric normalization for the original graph; the diffusion
        # graph is left unnormalized because its edge weights already
        # encode the (e.g. PPR) diffusion.
        self.encoder1 = GCN(in_dim, out_dim, num_layers, norm="both")
        self.encoder2 = GCN(in_dim, out_dim, num_layers, norm="none")

    def get_embedding(self, graph1, graph2, feat, edge_weight):
        # Graph-level embedding for downstream evaluation: the sum of
        # both views' projected global representations, detached from
        # the autograd graph.
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        global_v1 = self.global_mlp(global_v1)
        global_v2 = self.global_mlp(global_v2)

        return (global_v1 + global_v2).detach()

    def forward(self, graph1, graph2, feat, edge_weight, graph_id):
        # calculate node embeddings and graph embeddings
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        local_v1 = self.local_mlp(local_v1)
        local_v2 = self.local_mlp(local_v2)

        global_v1 = self.global_mlp(global_v1)
        global_v2 = self.global_mlp(global_v2)

        # calculate loss: contrast local representations of one view
        # against global representations of the other view
        loss1 = local_global_loss_(local_v1, global_v2, graph_id)
        loss2 = local_global_loss_(local_v2, global_v1, graph_id)

        loss = loss1 + loss2

        return loss
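

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original training
    # script. It reuses one toy graph for both views (a real run passes
    # the diffusion graph, with its edge weights, as graph2) and assumes
    # utils.local_global_loss_ (imported above) follows the
    # node-to-graph graph_id convention used in forward.
    import dgl

    g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
    bg = dgl.batch([g, g])  # a batch of two toy graphs
    feat = th.rand(bg.num_nodes(), 8)
    edge_weight = th.ones(bg.num_edges())  # stand-in for diffusion weights
    # Map each node to its graph: [0, 0, 0, 0, 1, 1, 1, 1]
    graph_id = th.repeat_interleave(
        th.arange(bg.batch_size), bg.batch_num_nodes()
    )

    model = MVGRL(in_dim=8, out_dim=16, num_layers=2)
    loss = model(bg, bg, feat, edge_weight, graph_id)
    emb = model.get_embedding(bg, bg, feat, edge_weight)
    print(loss.item(), emb.shape)  # emb: (num_graphs, out_dim) = (2, 16)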