Unverified Commit f8c9d58a authored by Chen Sirui's avatar Chen Sirui Committed by GitHub
Browse files

[Bug] Hardgat enable gpu training (#2632)



* add multihead in DotGatConv

* Fix spacing issue

* Add Unit test for dotgat

* Modified Unit test for dotgat

* Add transformer like divisor

* Update dotgatconv.py

* Update hgao.py

* Update train.py
Co-authored-by: default avatarChen <chesirui@3c22fbe5458c.ant.amazon.com>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent 23afe911
......@@ -79,7 +79,7 @@ class HardGAO(nn.Module):
# Use edge message passing function to get the weight from src node
graph.apply_edges(fn.copy_u('y','y'))
# Select Top k neighbors
subgraph = select_topk(graph,self.k,'y')
subgraph = select_topk(graph.cpu(),self.k,'y').to(graph.device)
# Sigmoid as information threshold
subgraph.ndata['y'] = torch.sigmoid(subgraph.ndata['y'])
# Using vector matrix elementwise mul for acceleration
......
......@@ -51,7 +51,7 @@ def main(args):
cuda = False
else:
cuda = True
g = g.int().to(args.gpu)
g = g.to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment