Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f8811c7d
Unverified
Commit
f8811c7d
authored
Dec 16, 2018
by
Da Zheng
Committed by
GitHub
Dec 16, 2018
Browse files
[BUGFIX] fix some minor problems in GAT (#308)
* fix gat. * fix context.
parent
dba36c87
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
9 deletions
+7
-9
examples/mxnet/gat/gat_batch.py
examples/mxnet/gat/gat_batch.py
+7
-9
No files found.
examples/mxnet/gat/gat_batch.py
View file @
f8811c7d
...
@@ -146,13 +146,12 @@ def main(args):
...
@@ -146,13 +146,12 @@ def main(args):
n_edges
=
data
.
graph
.
number_of_edges
()
n_edges
=
data
.
graph
.
number_of_edges
()
if
args
.
gpu
<
0
:
if
args
.
gpu
<
0
:
c
uda
=
False
c
tx
=
mx
.
cpu
(
0
)
else
:
else
:
cuda
=
True
ctx
=
mx
.
gpu
(
args
.
gpu
)
torch
.
cuda
.
set_device
(
args
.
gpu
)
features
=
features
.
as_in_context
(
ctx
)
features
=
features
.
cuda
()
labels
=
labels
.
as_in_context
(
ctx
)
labels
=
labels
.
cuda
()
mask
=
mask
.
as_in_context
(
ctx
)
mask
=
mask
.
cuda
()
# create GCN model
# create GCN model
g
=
DGLGraph
(
data
.
graph
)
g
=
DGLGraph
(
data
.
graph
)
...
@@ -169,9 +168,7 @@ def main(args):
...
@@ -169,9 +168,7 @@ def main(args):
args
.
attn_drop
,
args
.
attn_drop
,
args
.
residual
)
args
.
residual
)
if
cuda
:
model
.
initialize
(
ctx
=
ctx
)
model
.
cuda
()
model
.
initialize
()
# use optimizer
# use optimizer
trainer
=
gluon
.
Trainer
(
model
.
collect_params
(),
'adam'
,
{
'learning_rate'
:
args
.
lr
})
trainer
=
gluon
.
Trainer
(
model
.
collect_params
(),
'adam'
,
{
'learning_rate'
:
args
.
lr
})
...
@@ -189,6 +186,7 @@ def main(args):
...
@@ -189,6 +186,7 @@ def main(args):
#optimizer.zero_grad()
#optimizer.zero_grad()
loss
.
backward
()
loss
.
backward
()
trainer
.
step
(
features
.
shape
[
0
])
trainer
.
step
(
features
.
shape
[
0
])
loss
.
wait_to_read
()
if
epoch
>=
3
:
if
epoch
>=
3
:
dur
.
append
(
time
.
time
()
-
t0
)
dur
.
append
(
time
.
time
()
-
t0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment