Unverified commit 91b73823, authored by Minjie Wang, committed by GitHub

[Model] update gat (#390)

* update gat: add minus max for softmax

* small fix
parent 6c3dba86
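
The "minus max" in the first bullet is the standard numerically stable softmax: softmax is invariant to shifting its inputs by a constant, so subtracting the per-destination maximum before exponentiating prevents overflow without changing the result, unlike the clamp the old code applied after exp(). A minimal standalone PyTorch sketch of the idea (editorial illustration, not the DGL code itself):

import torch

# softmax is shift-invariant: softmax(a) == softmax(a - c) for any constant c.
# Subtracting the max keeps every exponent <= 0, so exp() cannot overflow.
a = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(a) / torch.exp(a).sum()                       # inf / inf -> nan
stable = torch.exp(a - a.max()) / torch.exp(a - a.max()).sum()  # well-defined

print(naive)   # tensor([nan, nan, nan])
print(stable)  # tensor([0.0900, 0.2447, 0.6652])
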
@@ -39,11 +39,11 @@ class GraphAttention(nn.Module):
         if feat_drop:
             self.feat_drop = nn.Dropout(feat_drop)
         else:
-            self.feat_drop = None
+            self.feat_drop = lambda x : x
         if attn_drop:
             self.attn_drop = nn.Dropout(attn_drop)
         else:
-            self.attn_drop = None
+            self.attn_drop = lambda x : x
         self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
@@ -60,22 +60,19 @@ class GraphAttention(nn.Module):
     def forward(self, inputs):
         # prepare
-        h = inputs  # NxD
-        if self.feat_drop:
-            h = self.feat_drop(h)
+        h = self.feat_drop(inputs)  # NxD
         ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
         head_ft = ft.transpose(0, 1)  # HxNxD'
         a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
         a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
-        if self.feat_drop:
-            ft = self.feat_drop(ft)
         self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
         # 1. compute edge attention
         self.g.apply_edges(self.edge_attention)
-        # 2. compute two results: one is the node features scaled by the dropped,
-        # unnormalized attention values; another is the normalizer of the attention values.
-        self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
-                          [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
+        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
+        self.edge_softmax()
+        # 2. compute the aggregated node features scaled by the dropped,
+        # unnormalized attention values.
+        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
         # 3. apply normalizer
         ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
         # 4. residual
@@ -90,10 +87,17 @@ class GraphAttention(nn.Module):
     def edge_attention(self, edges):
         # an edge UDF to compute unnormalized attention values from src and dst
         a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
-        a = torch.exp(a).clamp(-10, 10)  # use clamp to avoid overflow
-        if self.attn_drop:
-            a_drop = self.attn_drop(a)
-        return {'a' : a, 'a_drop' : a_drop}
+        return {'a' : a}
+
+    def edge_softmax(self):
+        # compute the max
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
+        # minus the max and exp
+        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
+        # compute dropout
+        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
+        # compute normalizer
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
 
 class GAT(nn.Module):
     def __init__(self,
@@ -247,7 +251,7 @@ if __name__ == '__main__':
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
                         help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=300,
+    parser.add_argument("--epochs", type=int, default=200,
                         help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=8,
                        help="number of hidden attention heads")
...
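
For readers less familiar with DGL's message-passing primitives, the four calls in the new edge_softmax amount to a softmax over each node's incoming edges, with the per-destination max subtracted before exp(). A loop-based reference sketch in plain PyTorch, with attention dropout omitted (the function name, shapes, and self-loop assumption are illustrative, not part of the commit):

import torch

def edge_softmax_reference(scores, dst, num_nodes):
    # scores: (E, H, 1) unnormalized attention value per edge and head.
    # dst[i]: destination node of edge i.
    # Assumes every node has at least one incoming edge (e.g. a self-loop).
    out = torch.empty_like(scores)
    for v in range(num_nodes):
        mask = dst == v                                        # incoming edges of node v
        s = scores[mask]
        s = torch.exp(s - s.max(dim=0, keepdim=True).values)   # minus max, then exp
        out[mask] = s / s.sum(dim=0, keepdim=True)             # normalize per destination
    return out
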