#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

import dgl.function as fn
from dgl.nn.pytorch import GATConv


class GraphConvLayer(nn.Module):
    """Sums "affine"-weighted source features onto destination nodes and
    projects the concatenation [dst_feat, aggregated] with a linear layer."""

    def __init__(self, in_feats, out_feats, bias=True):
        super(GraphConvLayer, self).__init__()
        self.mlp = nn.Linear(in_feats * 2, out_feats, bias=bias)

    def forward(self, bipartite, feat):
        # `feat` is either a (src, dst) tuple, or a single tensor whose first
        # num_dst_nodes() rows are the destination-node features.
        if isinstance(feat, tuple):
            srcfeat, dstfeat = feat
        else:
            srcfeat = feat
            dstfeat = feat[: bipartite.num_dst_nodes()]

        graph = bipartite.local_var()

        # Weight each source feature by its per-edge "affine" score and sum
        # the weighted messages onto the destination nodes.
        graph.srcdata["h"] = srcfeat
        graph.update_all(
            fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h")
        )

        # Concatenate destination features with the aggregated messages.
        gcn_feat = torch.cat([dstfeat, graph.dstdata["h"]], dim=-1)
        out = self.mlp(gcn_feat)
        return out


class GraphConv(nn.Module):
    """Graph convolution block: a K-head GATConv (if use_GAT) or the plain
    GraphConvLayer, followed by ReLU and optional dropout."""

    def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):
        super(GraphConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if use_GAT:
            # K attention heads; allow_zero_in_degree avoids errors for
            # isolated destination nodes.
            self.gcn_layer = GATConv(
                in_dim, out_dim, K, allow_zero_in_degree=True
            )
            self.bias = nn.Parameter(torch.Tensor(K, out_dim))
            init.constant_(self.bias, 0)
        else:
            self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)

        self.dropout = dropout
        self.use_GAT = use_GAT

    def forward(self, bipartite, features):
        out = self.gcn_layer(bipartite, features)

        if self.use_GAT:
            # Add the per-head bias and average over the K attention heads.
            out = torch.mean(out + self.bias, dim=1)

        out = out.reshape(out.shape[0], -1)
        out = F.relu(out)
        if self.dropout > 0:
            out = F.dropout(out, self.dropout, training=self.training)

        return out
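

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It builds
    # a small homogeneous DGL graph carrying the per-edge "affine" weights
    # that GraphConvLayer expects, then runs both variants once. The graph,
    # dimensions, and feature values below are illustrative assumptions only.
    import dgl

    num_nodes, in_dim, out_dim = 4, 8, 16
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    g = dgl.graph((src, dst), num_nodes=num_nodes)
    g.edata["affine"] = torch.ones(g.num_edges(), 1)  # used by fn.u_mul_e

    feat = torch.randn(num_nodes, in_dim)

    conv = GraphConv(in_dim, out_dim, dropout=0.1, use_GAT=False)
    print(conv(g, feat).shape)  # torch.Size([4, 16])

    gat = GraphConv(in_dim, out_dim, dropout=0.1, use_GAT=True, K=2)
    print(gat(g, feat).shape)  # torch.Size([4, 16])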