Unverified Commit 7ea777e1 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix HAN (#1790)

* Fix

* Fix

* Fix
parent 8a183e3f
......@@ -29,6 +29,11 @@ def main(args):
g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
val_mask, test_mask = load_data(args['dataset'])
if hasattr(torch, 'BoolTensor'):
train_mask = train_mask.bool()
val_mask = val_mask.bool()
test_mask = test_mask.bool()
features = features.to(args['device'])
labels = labels.to(args['device'])
train_mask = train_mask.to(args['device'])
......
......@@ -15,10 +15,11 @@ class SemanticAttention(nn.Module):
)
def forward(self, z):
w = self.project(z)
beta = torch.softmax(w, dim=1)
w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
return (beta * z).sum(1)
return (beta * z).sum(1) # (N, D * K)
class HANLayer(nn.Module):
"""
......
......@@ -25,10 +25,11 @@ class SemanticAttention(nn.Module):
)
def forward(self, z):
w = self.project(z)
beta = torch.softmax(w, dim=1)
w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
return (beta * z).sum(1)
return (beta * z).sum(1) # (N, D * K)
class HANLayer(nn.Module):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment