"tests/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "121b9db01d9ccdcd5f32586cb512d0e765dbccac"
Unverified Commit f8811c7d authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] fix some minor problems in GAT (#308)

* fix gat.

* fix context.
parent dba36c87
......@@ -146,13 +146,12 @@ def main(args):
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
ctx = mx.gpu(args.gpu)
features = features.as_in_context(ctx)
labels = labels.as_in_context(ctx)
mask = mask.as_in_context(ctx)
# create GCN model
g = DGLGraph(data.graph)
......@@ -169,9 +168,7 @@ def main(args):
args.attn_drop,
args.residual)
if cuda:
model.cuda()
model.initialize()
model.initialize(ctx=ctx)
# use optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})
......@@ -189,6 +186,7 @@ def main(args):
#optimizer.zero_grad()
loss.backward()
trainer.step(features.shape[0])
loss.wait_to_read()
if epoch >= 3:
dur.append(time.time() - t0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment