OpenDAS / dgl / Commits

Commit 57f480f5
Authored Jul 12, 2018 by Ivan Brugere
Committed by Minjie Wang, Jul 13, 2018
Parent: c50b90cf

LSTM fix (#28)

Correctly handling lstm model creation
Changes: 1 changed file with 29 additions and 20 deletions

examples/pytorch/geniepath.py  (+29, -20)
examples/pytorch/geniepath.py @ 57f480f5
@@ -32,6 +32,7 @@ class NodeReduceModule(nn.Module):
         self.fc = nn.ModuleList(
             [nn.Linear(input_dim, num_hidden, bias=False)
              for _ in range(num_heads)])
         self.attention = nn.ModuleList(
             [nn.Linear(num_hidden * 2, 1, bias=False)
              for _ in range(num_heads)])
@@ -61,14 +62,20 @@ class NodeUpdateModule(nn.Module):
 class NodeUpdateModule(nn.Module):
-    def __init__(self, residual, fc, act, aggregator):
+    def __init__(self, residual, fc, act, aggregator, lstm_size=0):
         super(NodeUpdateModule, self).__init__()
         self.residual = residual
         self.fc = fc
         self.act = act
         self.aggregator = aggregator
+        if lstm_size:
+            self.lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, num_layers=1)
+        else:
+            self.lstm = None
+        #print(fc[0].out_features)

     def forward(self, node, msgs_repr):
         # apply residual connection and activation for each head
         for i in range(len(msgs_repr)):
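A note on the change above: submodules assigned in __init__ are registered on the parent nn.Module, so their weights appear in parameters() and are trained with the rest of the model, whereas a layer constructed inside forward (as the code removed in the next hunk did) is rebuilt with fresh random weights on every call and never reaches the optimizer. A minimal sketch of the difference, using hypothetical InForward / InInit modules that are not part of this file:

    import torch
    import torch.nn as nn

    class InForward(nn.Module):
        # anti-pattern: the LSTM is rebuilt, untrained, on every call
        def forward(self, x):
            lstm = nn.LSTM(input_size=x.shape[1], hidden_size=x.shape[1], num_layers=1)
            out, _ = lstm(x.unsqueeze(0))
            return out.squeeze(0)

    class InInit(nn.Module):
        # fixed pattern: the LSTM is created once and registered as a submodule
        def __init__(self, lstm_size=0):
            super(InInit, self).__init__()
            if lstm_size:
                self.lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, num_layers=1)
            else:
                self.lstm = None

        def forward(self, x):
            if self.lstm is None:
                return x
            out, _ = self.lstm(x.unsqueeze(0))
            return out.squeeze(0)

    print(len(list(InForward().parameters())))          # 0: nothing for an optimizer to update
    print(len(list(InInit(lstm_size=8).parameters())))  # 4: the LSTM's weight/bias tensors

With lstm_size=0 the module simply skips the LSTM path, which is what the new else branch in forward below handles.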
@@ -80,26 +87,28 @@ class NodeUpdateModule(nn.Module):
         # aggregate multi-head results
         h = self.aggregator(msgs_repr)
-        c0 = torch.zeros(h.shape)
-        #print(h.shape)
-        if node['c'] is None:
-            c0 = torch.zeros(h.shape)
-        else:
-            c0 = node['c']
-        if node['h_i'] is None:
-            h0 = torch.zeros(h.shape)
-        else:
-            h0 = node['h_i']
-        lstm = nn.LSTM(input_size=h.shape[1], hidden_size=h.shape[1], num_layers=1)
-        #add dimension to handle sequential (create sequence of length 1)
-        h, (h_i, c) = lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
-        #remove sequential dim
-        h = torch.squeeze(h, 0)
-        h_i = torch.squeeze(h, 0)
-        c = torch.squeeze(c, 0)
-        return {'h': h, 'c': c, 'h_i': h_i}
+        if self.lstm is not None:
+            c0 = torch.zeros(h.shape)
+            if node['c'] is None:
+                c0 = torch.zeros(h.shape)
+            else:
+                c0 = node['c']
+            if node['h_i'] is None:
+                h0 = torch.zeros(h.shape)
+            else:
+                h0 = node['h_i']
+            #add dimension to handle sequential (create sequence of length 1)
+            h, (h_i, c) = self.lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
+            #remove sequential dim
+            h = torch.squeeze(h, 0)
+            h_i = torch.squeeze(h, 0)
+            c = torch.squeeze(c, 0)
+            return {'h': h, 'c': c, 'h_i': h_i}
+        else:
+            return {'h': h, 'c': None, 'h_i': None}

 class GeniePath(nn.Module):
     def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
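On the unsqueeze(0) / torch.squeeze(..., 0) pair kept in the new code: nn.LSTM with the default batch_first=False expects input of shape (seq_len, batch, input_size) and hidden/cell states of shape (num_layers, batch, hidden_size), so the per-node features are wrapped as a sequence of length 1 and unwrapped again afterwards. A standalone shape check, with illustrative sizes that do not come from the example:

    import torch
    import torch.nn as nn

    batch, feat = 5, 8
    lstm = nn.LSTM(input_size=feat, hidden_size=feat, num_layers=1)

    h = torch.randn(batch, feat)   # aggregated multi-head features
    h0 = torch.zeros(h.shape)      # previous LSTM hidden state (or zeros)
    c0 = torch.zeros(h.shape)      # previous LSTM cell state (or zeros)

    # add the sequence / layer dimension: (1, batch, feat)
    out, (h_i, c) = lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))

    # drop that dimension again
    out = torch.squeeze(out, 0)    # (batch, feat)
    h_i = torch.squeeze(h_i, 0)    # (batch, feat)
    c = torch.squeeze(c, 0)        # (batch, feat)
    print(out.shape, h_i.shape, c.shape)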
@@ -122,7 +131,7 @@ class GeniePath(nn.Module):
                                  attention_dropout))
             self.update_layers.append(
                 NodeUpdateModule(residual, self.reduce_layers[-1].fc, activation,
-                                 lambda x: torch.cat(x, 1)))
+                                 lambda x: torch.cat(x, 1), num_hidden * num_heads))
         # projection
         self.reduce_layers.append(
             NodeReduceModule(num_hidden * num_heads, num_classes, 1, input_dropout,
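The final hunk passes the new constructor argument through from GeniePath: the hidden-layer aggregator is lambda x: torch.cat(x, 1), so the per-node feature width entering the LSTM is num_hidden * num_heads, which is the value now supplied as lstm_size. A small sketch of that width calculation (the sizes are made up for illustration):

    import torch

    num_heads, num_hidden, batch = 4, 16, 10

    # one (batch, num_hidden) tensor per attention head
    msgs_repr = [torch.randn(batch, num_hidden) for _ in range(num_heads)]

    # the aggregator used for hidden layers: concatenate heads along dim 1
    aggregated = torch.cat(msgs_repr, 1)

    print(aggregated.shape)                                # torch.Size([10, 64])
    assert aggregated.shape[1] == num_hidden * num_heads   # matches lstm_size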