OpenDAS / dgl · Commits

Commit c87564d5
Authored Jun 22, 2018 by Lingfan Yu
minor bug fixes
Parent: 2daba976
Showing 2 changed files with 9 additions and 5 deletions (+9, -5):

examples/pytorch/gat.py   +2 -1
examples/pytorch/gcn.py   +7 -4
examples/pytorch/gat.py

```diff
@@ -126,6 +126,7 @@ def main(args):
     # convert labels and masks to tensor
     labels = torch.FloatTensor(y_train)
     mask = torch.FloatTensor(train_mask.astype(np.float32))
+    n_train = torch.sum(mask)
     for epoch in range(args.epochs):
         # reset grad
@@ -141,7 +142,7 @@ def main(args):
     # masked cross entropy loss
     # TODO: (lingfan) use gather to speed up
     logp = F.log_softmax(logits, 1)
-    loss = torch.mean(logp * labels * mask.view(-1, 1))
+    loss = -torch.sum(logp * labels * mask.view(-1, 1)) / n_train
     print("epoch {} loss: {}".format(epoch, loss.item()))
     loss.backward()
```
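The loss change (repeated below in gcn.py) is the substance of this commit. The old expression had no negation, so minimizing it drove the masked log-probabilities further from the labels, and `torch.mean` averaged over all N × C entries instead of over the training nodes. The new form is the standard masked cross entropy, with `n_train = torch.sum(mask)` counting the nonzero entries of the 0/1 mask, i.e. the training nodes. A minimal standalone sketch (the shapes, seed, and use of `F.one_hot` are my own; it checks the fixed expression against `F.nll_loss` restricted to the masked rows):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 5, 3                                  # made-up sizes: 5 nodes, 3 classes
logits = torch.randn(N, C)
target = torch.tensor([0, 2, 1, 0, 2])       # made-up class indices
labels = F.one_hot(target, C).float()        # one-hot float labels, like y_train
mask = torch.tensor([1., 1., 0., 1., 0.])    # 3 of the 5 nodes are training nodes
n_train = torch.sum(mask)

logp = F.log_softmax(logits, 1)

# fixed loss from this commit: negate, sum over masked entries, divide by n_train
loss = -torch.sum(logp * labels * mask.view(-1, 1)) / n_train

# reference: mean NLL over just the training nodes
ref = F.nll_loss(logp[mask.bool()], target[mask.bool()])
assert torch.allclose(loss, ref)
```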
examples/pytorch/gcn.py

```diff
@@ -16,9 +16,10 @@ class NodeUpdateModule(nn.Module):
     def forward(self, node, msgs):
         h = node['h']
+        # (lingfan): how to write dropout, is the following correct?
         if self.p is not None:
             h = F.dropout(h, p=self.p)
-        # aggregator messages
+        # aggregate messages
         for msg in msgs:
             h += msg
         h = self.linear(h)
```
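The question in the newly added comment is a fair one. As written, `F.dropout(h, p=self.p)` uses the functional default `training=True`, so dropout keeps firing even after `model.eval()`; the functional form never sees the module's training flag. A minimal sketch of the usual fix, passing `self.training` explicitly (the `NodeUpdate` class here is a hypothetical stand-in, not the module from this file):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NodeUpdate(nn.Module):
    # hypothetical stand-in for NodeUpdateModule, dropout handling only
    def __init__(self, p=0.5):
        super(NodeUpdate, self).__init__()
        self.p = p

    def forward(self, h):
        if self.p is not None:
            # training=self.training makes dropout a no-op under model.eval()
            h = F.dropout(h, p=self.p, training=self.training)
        return h

m = NodeUpdate(p=0.5).eval()
x = torch.ones(4, 8)
assert torch.equal(m(x), x)   # identity in eval mode
```

Wrapping the rate in an `nn.Dropout` layer instead would track the module's train/eval state automatically.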
```diff
@@ -29,7 +30,7 @@ class NodeUpdateModule(nn.Module):
 class GCN(nn.Module):
-    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout):
+    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout=None, output_projection=True):
         super(GCN, self).__init__()
         self.layers = nn.ModuleList()
         # hidden layers
@@ -39,6 +40,7 @@ class GCN(nn.Module):
                 NodeUpdateModule(last_dim, num_hidden, act=activation, p=dropout))
             last_dim = num_hidden
         # output layer
-        self.layers.append(NodeUpdateModule(num_hidden, num_classes, p=dropout))
+        if output_projection:
+            self.layers.append(NodeUpdateModule(num_hidden, num_classes, p=dropout))

     def forward(self, g):
```
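The widened signature makes dropout optional (`NodeUpdateModule.forward` already skips `F.dropout` when `p` is `None`) and lets a caller omit the final `num_hidden` to `num_classes` projection, for example to reuse the hidden layers as a feature encoder. A hedged sketch of how the two new defaults might be exercised; the argument values are invented:

```python
import torch.nn.functional as F

# full model: hidden layers plus the output classification layer (the defaults)
model = GCN(input_dim=1433, num_hidden=16, num_classes=7,
            num_layers=2, activation=F.relu)

# encoder only: dropout enabled, final projection to num_classes skipped
encoder = GCN(input_dim=1433, num_hidden=16, num_classes=7,
              num_layers=2, activation=F.relu,
              dropout=0.5, output_projection=False)
```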
```diff
@@ -72,6 +74,7 @@ def main(args):
     # convert labels and masks to tensor
     labels = torch.FloatTensor(y_train)
     mask = torch.FloatTensor(train_mask.astype(np.float32))
+    n_train = torch.sum(mask)
     for epoch in range(args.epochs):
         # reset grad
@@ -87,7 +90,7 @@ def main(args):
     # masked cross entropy loss
     # TODO: (lingfan) use gather to speed up
     logp = F.log_softmax(logits, 1)
-    loss = torch.mean(logp * labels * mask.view(-1, 1))
+    loss = -torch.sum(logp * labels * mask.view(-1, 1)) / n_train
     print("epoch {} loss: {}".format(epoch, loss.item()))
     loss.backward()
```