"vscode:/vscode.git/clone" did not exist on "704f8e8ed1a4ab992ac626bc91cd62e4909faa8f"
model.py 7.3 KB
Newer Older
Ziniu Hu's avatar
Ziniu Hu committed
1
2
3
4
5
6
import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
7
from dgl.nn.functional import edge_softmax
Ziniu Hu's avatar
Ziniu Hu committed
8
9

class HGTLayer(nn.Module):
    """One Heterogeneous Graph Transformer (HGT) layer.

    For every canonical edge type ``(srctype, etype, dsttype)`` of the input
    graph the layer:

    1. projects source nodes to keys/values and destination nodes to queries
       with per-node-type linear maps;
    2. transforms keys/values with per-relation, per-head ``d_k x d_k``
       matrices and scores each edge with a scaled, relation-weighted dot
       product, normalised by an edge softmax over each destination's
       incoming edges;
    3. sums the attention-weighted values per relation, then averages the
       per-relation results for every destination node type.

    Finally each node type gets a per-type output projection (with dropout),
    mixed with the input through a sigmoid-gated skip connection, and an
    optional per-type LayerNorm.

    Args:
        in_dim: Feature size of the input features ``h[ntype]``.
        out_dim: Output feature size; must be divisible by ``n_heads``. The
            gated skip adds ``h[ntype]`` to the projected output, so callers
            are expected to use ``in_dim == out_dim``.
        node_dict: Mapping node-type name -> index. The indices select entries
            of per-type module lists of length ``len(node_dict)``, so they
            should be ``0 .. len(node_dict) - 1``.
        edge_dict: Mapping edge-type name -> relation index in
            ``0 .. len(edge_dict) - 1``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the projected output.
        use_norm: If True, apply a per-node-type LayerNorm to the output.

    Raises:
        ValueError: If ``out_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 node_dict,
                 edge_dict,
                 n_heads,
                 dropout = 0.2,
                 use_norm = False):
        super().__init__()
        if out_dim % n_heads != 0:
            # Each head gets out_dim // n_heads channels. A remainder would make
            # the ``.view(-1, n_heads, d_k)`` reshapes in ``forward`` fail or,
            # for some node counts, silently regroup rows across nodes.
            raise ValueError(
                'out_dim (%d) must be divisible by n_heads (%d)' % (out_dim, n_heads))

        self.in_dim        = in_dim
        self.out_dim       = out_dim
        self.node_dict     = node_dict
        self.edge_dict     = edge_dict
        self.num_types     = len(node_dict)
        self.num_relations = len(edge_dict)
        # Not used by this layer; kept so external code reading it still works.
        self.total_rel     = self.num_types * self.num_relations * self.num_types
        self.n_heads       = n_heads
        self.d_k           = out_dim // n_heads
        self.sqrt_dk       = math.sqrt(self.d_k)
        # Never written by this layer; kept for backward compatibility.
        self.att           = None

        # Per-node-type projections, indexed by node_dict[ntype].
        self.k_linears   = nn.ModuleList()
        self.q_linears   = nn.ModuleList()
        self.v_linears   = nn.ModuleList()
        self.a_linears   = nn.ModuleList()
        self.norms       = nn.ModuleList()
        self.use_norm    = use_norm

        for _ in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim,   out_dim))
            self.q_linears.append(nn.Linear(in_dim,   out_dim))
            self.v_linears.append(nn.Linear(in_dim,   out_dim))
            self.a_linears.append(nn.Linear(out_dim,  out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation parameters, indexed by edge_dict[etype]:
        #   relation_pri: per-head scale applied to attention scores (starts at 1);
        #   relation_att / relation_msg: per-head d_k x d_k maps applied to keys / values.
        self.relation_pri   = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
        self.relation_att   = nn.Parameter(torch.empty(self.num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg   = nn.Parameter(torch.empty(self.num_relations, n_heads, self.d_k, self.d_k))
        # Per-node-type skip logits; sigmoid(skip) weights the new features vs. the input.
        self.skip           = nn.Parameter(torch.ones(self.num_types))
        self.drop           = nn.Dropout(dropout)

        # torch.empty leaves memory uninitialised; these calls fill it.
        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def forward(self, G, h):
        """Apply one HGT layer.

        Args:
            G: DGL heterograph whose node/edge type names appear in
                ``node_dict`` / ``edge_dict``. Features are only written inside
                ``G.local_scope()``, so the caller's graph is left unchanged.
            h: Dict mapping every node type of ``G`` to a ``(num_nodes, in_dim)``
                feature tensor.

        Returns:
            Dict mapping every node type of ``G`` to its new
            ``(num_nodes, out_dim)`` features.
        """
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict
            funcs = {}
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]

                k_linear = self.k_linears[node_dict[srctype]]
                v_linear = self.v_linears[node_dict[srctype]]
                q_linear = self.q_linears[node_dict[dsttype]]

                # Split the out_dim channels into n_heads heads of size d_k.
                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

                e_id = edge_dict[etype]

                relation_att = self.relation_att[e_id]
                relation_pri = self.relation_pri[e_id]
                relation_msg = self.relation_msg[e_id]

                # Per-head matrix product: (N, H, d_k) x (H, d_k, d_k) -> (N, H, d_k).
                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                sub_graph.srcdata['k'] = k
                sub_graph.dstdata['q'] = q
                # Values are keyed by relation id so that relations sharing a
                # source node type do not overwrite each other before the single
                # multi_update_all call below.
                sub_graph.srcdata['v_%d' % e_id] = v

                # Per-edge, per-head score q . k, reduced over the trailing axis,
                # scaled by the relation prior and 1/sqrt(d_k).
                sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
                attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
                attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')

                sub_graph.edata['t'] = attn_score.unsqueeze(-1)

                # Key by canonical edge type. This handles edge-type names shared
                # by several canonical types, and relations in edge_dict that are
                # absent from G.
                funcs[(srctype, etype, dsttype)] = (
                    fn.u_mul_e('v_%d' % e_id, 't', 'm'), fn.sum('m', 't'))

            # Sum messages within each relation, then average across relations.
            G.multi_update_all(funcs, cross_reducer = 'mean')

            new_h = {}
            for ntype in G.ntypes:
                # Target-specific aggregation:
                #   out = sigmoid(skip) * dropout(A[ntype](agg)) + (1 - sigmoid(skip)) * h
                # followed by LayerNorm[ntype] when use_norm is set.
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                node_data = G.nodes[ntype].data
                if 't' in node_data:
                    t = node_data['t'].view(-1, self.out_dim)
                else:
                    # No relation in G targets this node type, so it received no
                    # messages at all. Use a zero aggregate, as DGL's built-in
                    # reducers do for zero in-degree nodes, instead of a KeyError.
                    t = h[ntype].new_zeros((h[ntype].shape[0], self.out_dim))
                trans_out = self.drop(self.a_linears[n_id](t))
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h
108

Ziniu Hu's avatar
Ziniu Hu committed
109
class HGT(nn.Module):
    """Complete HGT model: per-type input adapters, stacked HGTLayers, linear head.

    Args:
        G: Accepted for interface compatibility; the constructor does not read it.
        node_dict: Mapping node-type name -> index into the per-type adapters.
        edge_dict: Mapping edge-type name -> relation index, passed to each layer.
        n_inp: Size of the raw ``'inp'`` node features.
        n_hid: Hidden size of the adapters and of every HGTLayer.
        n_out: Size of the final output.
        n_layers: Number of stacked HGTLayer modules.
        n_heads: Attention heads per HGTLayer.
        use_norm: Whether each HGTLayer applies LayerNorm to its output.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
        super().__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict

        # Create the adapters first and the layers second, so parameter
        # initialisation consumes the random number generator in the same order
        # as always ...
        adapters = [nn.Linear(n_inp, n_hid) for _ in range(len(node_dict))]
        layers = [
            HGTLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm = use_norm)
            for _ in range(n_layers)
        ]

        # ... but register `gcs` before `adapt_ws`, so parameters() ordering (and
        # therefore any saved optimizer state) is unchanged.
        self.gcs = nn.ModuleList(layers)
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList(adapters)
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        """Return the output-head result for nodes of type ``out_key``.

        Every node type's ``'inp'`` features go through its adapter and GELU,
        then through ``n_layers`` HGTLayers, before the final linear head.
        """
        h = {
            ntype: F.gelu(self.adapt_ws[self.node_dict[ntype]](G.nodes[ntype].data['inp']))
            for ntype in G.ntypes
        }
        for depth in range(self.n_layers):
            h = self.gcs[depth](G, h)
        return self.out(h[out_key])
Ziniu Hu's avatar
Ziniu Hu committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

class HeteroRGCNLayer(nn.Module):
    """One relational GCN layer over a DGL heterograph.

    Keeps one ``nn.Linear`` (W_r) per edge-type name. For each relation it sends
    ``W_r * h_src`` to destination nodes and averages it there. It then sums the
    per-relation averages for every destination node type.

    Args:
        in_size: Input feature size.
        out_size: Output feature size.
        etypes: Iterable of edge-type names; one projection is created per name.
    """

    def __init__(self, in_size, out_size, etypes):
        super().__init__()
        # W_r for each relation, keyed by edge-type name.
        self.weight = nn.ModuleDict({
                name : nn.Linear(in_size, out_size) for name in etypes
            })

    def forward(self, G, feat_dict):
        """Run one round of relational message passing.

        Args:
            G: DGL heterograph whose edge-type names are keys of ``self.weight``.
                Intermediate features are written only inside
                ``G.local_scope()``, so the caller's graph is not modified.
            feat_dict: Dict mapping each node type to its input features.

        Returns:
            Dict mapping each node type of ``G`` to its aggregated features
            ``'h'``.
        """
        with G.local_scope():
            funcs = {}
            for srctype, etype, dsttype in G.canonical_etypes:
                # Compute W_r * h for the source nodes of this relation.
                Wh = self.weight[etype](feat_dict[srctype])
                # Store it (scoped) in the graph for message passing.
                G.nodes[srctype].data['Wh_%s' % etype] = Wh
                # Per-relation (message_func, reduce_func). Every relation writes
                # to the same destination feature 'h', so the cross-type reducer
                # below combines them. Keying by canonical edge type keeps this
                # unambiguous when one edge-type name links several node types.
                funcs[(srctype, etype, dsttype)] = (
                    fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
            # Trigger message passing for all relations at once. The second
            # argument is the cross-type reducer ("sum", "max", "min", "mean",
            # "stack").
            G.multi_update_all(funcs, 'sum')
            # The returned tensors remain valid after the local scope exits.
            return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
162
163


Ziniu Hu's avatar
Ziniu Hu committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class HeteroRGCN(nn.Module):
    """Two-layer heterogeneous RGCN with a leaky-ReLU between the layers.

    Args:
        G: Heterograph whose ``etypes`` determine the per-relation weights.
        in_size: Size of the ``'inp'`` node features.
        hidden_size: Output size of the first layer.
        out_size: Output size of the second layer.
    """

    def __init__(self, G, in_size, hidden_size, out_size):
        super().__init__()
        relation_names = G.etypes
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, relation_names)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, relation_names)

    def forward(self, G, out_key):
        """Return the second layer's output for nodes of type ``out_key``."""
        features = {}
        for ntype in G.ntypes:
            features[ntype] = G.nodes[ntype].data['inp']

        hidden = self.layer1(G, features)

        activated = {}
        for ntype, tensor in hidden.items():
            activated[ntype] = F.leaky_relu(tensor)

        outputs = self.layer2(G, activated)
        return outputs[out_key]