model.py 7.65 KB
Newer Older
Ziniu Hu's avatar
Ziniu Hu committed
1
import math
2

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3
4
5
import dgl
import dgl.function as fn

Ziniu Hu's avatar
Ziniu Hu committed
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F
9
from dgl.nn.functional import edge_softmax
Ziniu Hu's avatar
Ziniu Hu committed
10

11

Ziniu Hu's avatar
Ziniu Hu committed
12
class HGTLayer(nn.Module):
    """A single Heterogeneous Graph Transformer (HGT) layer.

    Type-aware multi-head attention over a heterogeneous graph: each node
    type has its own K/Q/V/output projections, and each edge type
    (relation) has its own attention transform, message transform, and
    per-head prior weights.

    Parameters
    ----------
    in_dim : int
        Input feature size (shared by all node types).
    out_dim : int
        Output feature size; must be divisible by ``n_heads``.
    node_dict : dict[str, int]
        Maps node-type name -> contiguous integer id used to index the
        per-type module lists.
    edge_dict : dict[str, int]
        Maps edge-type name -> contiguous integer id used to index the
        per-relation parameter tensors.
    n_heads : int
        Number of attention heads.
    dropout : float, optional
        Dropout applied to the aggregated output before the skip mix.
    use_norm : bool, optional
        If True, apply a per-node-type LayerNorm to the output.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        node_dict,
        edge_dict,
        n_heads,
        dropout=0.2,
        use_norm=False,
    ):
        super(HGTLayer, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.num_types = len(node_dict)
        self.num_relations = len(edge_dict)
        # Count of possible (src type, relation, dst type) triples;
        # bookkeeping only — not used in the computation below.
        self.total_rel = self.num_types * self.num_relations * self.num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads  # per-head feature size
        self.sqrt_dk = math.sqrt(self.d_k)  # attention scaling factor
        self.att = None  # placeholder; never assigned elsewhere in this class

        # One K/Q/V/output projection per node type, indexed by node_dict id.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm

        for t in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation parameters, indexed by edge_dict id:
        #   relation_pri — learned per-head attention prior (multiplier),
        #   relation_att — d_k x d_k key transform used for attention,
        #   relation_msg — d_k x d_k value transform used for messages.
        self.relation_pri = nn.Parameter(
            torch.ones(self.num_relations, self.n_heads)
        )
        self.relation_att = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.relation_msg = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        # Per-node-type skip gate (passed through sigmoid in forward()).
        self.skip = nn.Parameter(torch.ones(self.num_types))
        self.drop = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def forward(self, G, h):
        """Run one round of heterogeneous attention message passing.

        Parameters
        ----------
        G : dgl heterograph
            Graph to propagate over; feature writes are confined to a
            ``local_scope`` so G's own frames are left untouched.
        h : dict[str, torch.Tensor]
            Input features per node type; each tensor is viewed as
            (-1, n_heads, d_k) after projection, so the feature size must
            equal out_dim (and in_dim == out_dim for the residual mix).

        Returns
        -------
        dict[str, torch.Tensor]
            Updated per-node-type features of width ``out_dim``.
        """
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict
            # Step 1: per-relation attention scores on each edge.
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]

                # Type-specific projections: K and V come from the source
                # node type, Q from the destination node type.
                k_linear = self.k_linears[node_dict[srctype]]
                v_linear = self.v_linears[node_dict[srctype]]
                q_linear = self.q_linears[node_dict[dsttype]]

                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

                e_id = self.edge_dict[etype]

                relation_att = self.relation_att[e_id]
                relation_pri = self.relation_pri[e_id]
                relation_msg = self.relation_msg[e_id]

                # Relation-specific transform per head:
                # (N, heads, d_k) @ (heads, d_k, d_k) -> (N, heads, d_k).
                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                sub_graph.srcdata["k"] = k
                sub_graph.dstdata["q"] = q
                # Values are keyed by relation id so multi_update_all below
                # can select the matching message field for each etype.
                sub_graph.srcdata["v_%d" % e_id] = v

                # Edge score: dot(dst query, src key), scaled by sqrt(d_k)
                # and weighted by the learned per-relation prior.
                sub_graph.apply_edges(fn.v_dot_u("q", "k", "t"))
                attn_score = (
                    sub_graph.edata.pop("t").sum(-1)
                    * relation_pri
                    / self.sqrt_dk
                )
                # Softmax over the incoming edges of each destination node.
                attn_score = edge_softmax(sub_graph, attn_score, norm_by="dst")

                sub_graph.edata["t"] = attn_score.unsqueeze(-1)

            # Step 2: attention-weighted message passing for all relations;
            # results landing on the same destination type are combined
            # with the "mean" cross-relation reducer.
            G.multi_update_all(
                {
                    etype: (
                        fn.u_mul_e("v_%d" % e_id, "t", "m"),
                        fn.sum("m", "t"),
                    )
                    for etype, e_id in edge_dict.items()
                },
                cross_reducer="mean",
            )

            new_h = {}
            for ntype in G.ntypes:
                """
                Step 3: Target-specific Aggregation
                x = norm( W[node_type] * gelu( Agg(x) ) + x )
                """
                n_id = node_dict[ntype]
                # Learned gate between the new aggregate and the residual
                # input features for this node type.
                alpha = torch.sigmoid(self.skip[n_id])
                # NOTE(review): assumes every node type received at least
                # one message so that "t" exists — confirm for graphs with
                # isolated node types.
                t = G.nodes[ntype].data["t"].view(-1, self.out_dim)
                trans_out = self.drop(self.a_linears[n_id](t))
                # Residual mix; requires h[ntype] width == out_dim.
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h
131

132

Ziniu Hu's avatar
Ziniu Hu committed
133
class HGT(nn.Module):
    """Heterogeneous Graph Transformer network.

    Projects each node type's raw features into a shared hidden space,
    applies ``n_layers`` HGT layers, then maps the requested node type's
    representation to ``n_out`` logits.
    """

    def __init__(
        self,
        G,
        node_dict,
        edge_dict,
        n_inp,
        n_hid,
        n_out,
        n_layers,
        n_heads,
        use_norm=True,
    ):
        super(HGT, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        # One input-adaptation projection per node type, indexed by the
        # node_dict id of that type.
        self.adapt_ws = nn.ModuleList(
            nn.Linear(n_inp, n_hid) for _ in range(len(node_dict))
        )
        # Stack of identical HGT blocks operating on the hidden dimension.
        for _ in range(n_layers):
            self.gcs.append(
                HGTLayer(
                    n_hid,
                    n_hid,
                    node_dict,
                    edge_dict,
                    n_heads,
                    use_norm=use_norm,
                )
            )
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        """Run the network on graph G; return logits for node type out_key."""
        # Per-type input projection followed by GELU, reading the raw
        # features stored under the "inp" node field.
        h = {
            ntype: F.gelu(
                self.adapt_ws[self.node_dict[ntype]](G.nodes[ntype].data["inp"])
            )
            for ntype in G.ntypes
        }
        for layer in self.gcs:
            h = layer(G, h)
        return self.out(h[out_key])
Ziniu Hu's avatar
Ziniu Hu committed
178

179

Ziniu Hu's avatar
Ziniu Hu committed
180
181
182
183
class HeteroRGCNLayer(nn.Module):
    """One heterogeneous R-GCN layer with a distinct weight per relation."""

    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_r for each relation, keyed by edge-type name.
        self.weight = nn.ModuleDict(
            {name: nn.Linear(in_size, out_size) for name in etypes}
        )

    def forward(self, G, feat_dict):
        """Propagate `feat_dict` (features per node type) one layer.

        Returns the updated per-node-type feature dictionary.
        """
        funcs = {}
        for srctype, etype, _dsttype in G.canonical_etypes:
            field = "Wh_%s" % etype
            # Project source features with the relation-specific weight
            # and stash the result on the source nodes for message passing.
            G.nodes[srctype].data[field] = self.weight[etype](
                feat_dict[srctype]
            )
            # Per-relation (message, reduce) pair. Every relation reduces
            # into the same destination field "h", which signals the
            # type-wise reducer below to aggregate across relations.
            funcs[etype] = (fn.copy_u(field, "m"), fn.mean("m", "h"))
        # Trigger message passing for all relations at once; "sum" is the
        # cross-relation (type-wise) reducer — other options include
        # "max", "min", "mean", "stack".
        G.multi_update_all(funcs, "sum")
        # Collect the aggregated features for every node type.
        return {ntype: G.nodes[ntype].data["h"] for ntype in G.ntypes}
207
208


Ziniu Hu's avatar
Ziniu Hu committed
209
210
211
212
213
214
215
216
class HeteroRGCN(nn.Module):
    """Two-layer heterogeneous R-GCN baseline model."""

    def __init__(self, G, in_size, hidden_size, out_size):
        super(HeteroRGCN, self).__init__()
        # Two stacked relational layers; both cover every edge type of G.
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes)

    def forward(self, G, out_key):
        """Return output logits for nodes of type `out_key`."""
        # Gather the raw input features stored under the "inp" node field.
        features = {ntype: G.nodes[ntype].data["inp"] for ntype in G.ntypes}
        hidden = self.layer1(G, features)
        activated = {name: F.leaky_relu(feat) for name, feat in hidden.items()}
        # Second layer produces logits; select the requested node type.
        return self.layer2(G, activated)[out_key]