"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "88972172d87a776a34373a8e2d66791abf61e222"
Unverified Commit f8811c7d authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] fix some minor problems in GAT (#308)

* fix gat.

* fix context.
parent dba36c87
...@@ -146,13 +146,12 @@ def main(args): ...@@ -146,13 +146,12 @@ def main(args):
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
if args.gpu < 0: if args.gpu < 0:
cuda = False ctx = mx.cpu(0)
else: else:
cuda = True ctx = mx.gpu(args.gpu)
torch.cuda.set_device(args.gpu) features = features.as_in_context(ctx)
features = features.cuda() labels = labels.as_in_context(ctx)
labels = labels.cuda() mask = mask.as_in_context(ctx)
mask = mask.cuda()
# create GCN model # create GCN model
g = DGLGraph(data.graph) g = DGLGraph(data.graph)
...@@ -169,9 +168,7 @@ def main(args): ...@@ -169,9 +168,7 @@ def main(args):
args.attn_drop, args.attn_drop,
args.residual) args.residual)
if cuda: model.initialize(ctx=ctx)
model.cuda()
model.initialize()
# use optimizer # use optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr}) trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})
...@@ -189,6 +186,7 @@ def main(args): ...@@ -189,6 +186,7 @@ def main(args):
#optimizer.zero_grad() #optimizer.zero_grad()
loss.backward() loss.backward()
trainer.step(features.shape[0]) trainer.step(features.shape[0])
loss.wait_to_read()
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) dur.append(time.time() - t0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment