"tools/distpartitioning/gloo_wrapper.py" did not exist on "7c598aac6c25fbee53e52f6bd54c2fd04bad2151"
graphconv.py 1.83 KB
Newer Older
1
2
3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from torch.nn import init


class GraphConvLayer(nn.Module):
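    """SAGE-style convolution: sums neighbor features weighted by the
    per-edge "affine" value, concatenates the aggregate with the
    destination features, and applies a linear projection. Assumes the
    input block carries an "affine" edge feature."""
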
    def __init__(self, in_feats, out_feats, bias=True):
        super(GraphConvLayer, self).__init__()
        self.mlp = nn.Linear(in_feats * 2, out_feats, bias=bias)

    def forward(self, bipartite, feat):
        if isinstance(feat, tuple):
            srcfeat, dstfeat = feat
        else:
            srcfeat = feat
            dstfeat = feat[: bipartite.num_dst_nodes()]

        graph = bipartite.local_var()

        # Weighted neighbor aggregation: scale each source feature by the
        # per-edge "affine" weight and sum the messages at each destination.
        graph.srcdata["h"] = srcfeat
        graph.update_all(
            fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h")
        )

        # Concatenate destination features with the aggregated neighborhood,
        # then project back to out_feats.
        gcn_feat = torch.cat([dstfeat, graph.dstdata["h"]], dim=-1)
        out = self.mlp(gcn_feat)
        return out
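

# Reference sketch (illustrative addition, not in the original file): the
# update_all call in GraphConvLayer.forward is equivalent to this explicit
# per-edge computation, assuming bipartite.edata["affine"] has shape (E, 1).
def _weighted_sum_reference(bipartite, srcfeat):
    u, v = bipartite.edges()
    # Per-edge weighted messages: (E, 1) * (E, in_feats) -> (E, in_feats).
    messages = bipartite.edata["affine"] * srcfeat[u]
    agg = torch.zeros(
        bipartite.num_dst_nodes(),
        srcfeat.shape[-1],
        dtype=srcfeat.dtype,
        device=srcfeat.device,
    )
    agg.index_add_(0, v, messages)  # sum incoming messages per dst node
    return agg
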

class GraphConv(nn.Module):
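    """Graph convolution block: applies GraphConvLayer, or a K-head GATConv
    whose head outputs are bias-shifted and averaged, followed by ReLU and
    optional dropout."""
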
    def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):
        super(GraphConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if use_GAT:
            # K-head GATConv; output shape (N, K, out_dim). A learnable
            # per-head bias is added before heads are averaged in forward().
            self.gcn_layer = GATConv(
                in_dim, out_dim, K, allow_zero_in_degree=True
            )
            self.bias = nn.Parameter(torch.Tensor(K, out_dim))
            init.constant_(self.bias, 0)
        else:
            self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)

        self.dropout = dropout
        self.use_GAT = use_GAT

    def forward(self, bipartite, features):
        out = self.gcn_layer(bipartite, features)

        if self.use_GAT:
            # Add the per-head bias and average over the K attention heads.
            out = torch.mean(out + self.bias, dim=1)

        out = out.reshape(out.shape[0], -1)
        out = F.relu(out)
        if self.dropout > 0:
            out = F.dropout(out, self.dropout, training=self.training)

        return out
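

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal smoke test, assuming DGL is installed and that edge weights are
# provided under the "affine" key expected by GraphConvLayer, with shape
# (E, 1) so they broadcast against the (E, in_dim) per-edge messages.
if __name__ == "__main__":
    import dgl

    # Toy bipartite block: 4 source nodes, 2 destination nodes, 3 edges.
    block = dgl.create_block(
        (torch.tensor([0, 1, 3]), torch.tensor([0, 1, 1])),
        num_src_nodes=4,
        num_dst_nodes=2,
    )
    block.edata["affine"] = torch.rand(block.num_edges(), 1)

    feat = torch.randn(4, 16)
    conv = GraphConv(in_dim=16, out_dim=8, dropout=0.1)
    print(conv(block, feat).shape)  # torch.Size([2, 8])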