"src/array/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f2c80b440e80226441dc6c11a95ade10defaaf11"
Unverified Commit 7ea777e1 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix HAN (#1790)

* Fix

* Fix

* Fix
parent 8a183e3f
...@@ -29,6 +29,11 @@ def main(args): ...@@ -29,6 +29,11 @@ def main(args):
g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \ g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
val_mask, test_mask = load_data(args['dataset']) val_mask, test_mask = load_data(args['dataset'])
if hasattr(torch, 'BoolTensor'):
train_mask = train_mask.bool()
val_mask = val_mask.bool()
test_mask = test_mask.bool()
features = features.to(args['device']) features = features.to(args['device'])
labels = labels.to(args['device']) labels = labels.to(args['device'])
train_mask = train_mask.to(args['device']) train_mask = train_mask.to(args['device'])
......
...@@ -15,10 +15,11 @@ class SemanticAttention(nn.Module): ...@@ -15,10 +15,11 @@ class SemanticAttention(nn.Module):
) )
def forward(self, z): def forward(self, z):
w = self.project(z) w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=1) beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
return (beta * z).sum(1) return (beta * z).sum(1) # (N, D * K)
class HANLayer(nn.Module): class HANLayer(nn.Module):
""" """
......
...@@ -25,10 +25,11 @@ class SemanticAttention(nn.Module): ...@@ -25,10 +25,11 @@ class SemanticAttention(nn.Module):
) )
def forward(self, z): def forward(self, z):
w = self.project(z) w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=1) beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
return (beta * z).sum(1) return (beta * z).sum(1) # (N, D * K)
class HANLayer(nn.Module): class HANLayer(nn.Module):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment