Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
60899a36
"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b8cf84a3f902550937255c5b28b39827ba52beb6"
Commit
60899a36
authored
Jun 14, 2018
by
zzhang-cn
Browse files
pytorch foldre org
parent
2ed3989c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
140 deletions
+83
-140
graph.py
graph.py
+2
-140
pytorch/util.py
pytorch/util.py
+81
-0
No files found.
graph.py
View file @
60899a36
import
networkx
as
nx
from
networkx.classes.digraph
import
DiGraph
'''
Defult modules: this is Pytorch specific
- MessageModule: copy
- UpdateModule: vanilla RNN
- ReadoutModule: bag of words
- ReductionModule: bag of words
'''
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
class
DefaultMessageModule
(
nn
.
Module
):
"""
Default message module:
- copy
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultMessageModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x
):
return
x
class
DefaultUpdateModule
(
nn
.
Module
):
"""
Default update module:
- a vanilla GRU with ReLU, or GRU
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultUpdateModule
,
self
).
__init__
()
h_dims
=
self
.
h_dims
=
kwargs
.
get
(
'h_dims'
,
128
)
net_type
=
self
.
net_type
=
kwargs
.
get
(
'net_type'
,
'fwd'
)
n_func
=
self
.
n_func
=
kwargs
.
get
(
'n_func'
,
1
)
self
.
f_idx
=
0
self
.
reduce_func
=
DefaultReductionModule
()
if
net_type
==
'gru'
:
self
.
net
=
[
nn
.
GRUCell
(
h_dims
,
h_dims
)
for
i
in
range
(
n_func
)]
else
:
self
.
net
=
[
nn
.
Linear
(
2
*
h_dims
,
h_dims
)
for
i
in
range
(
n_func
)]
def
forward
(
self
,
x
,
msgs
):
if
not
th
.
is_tensor
(
x
):
x
=
th
.
zeros_like
(
msgs
[
0
])
m
=
self
.
reduce_func
(
msgs
)
assert
(
self
.
f_idx
<
self
.
n_func
)
if
self
.
net_type
==
'gru'
:
out
=
self
.
net
[
self
.
f_idx
](
m
,
x
)
else
:
_in
=
th
.
cat
((
m
,
x
),
1
)
out
=
F
.
relu
(
self
.
net
[
self
.
f_idx
](
_in
))
self
.
f_idx
+=
1
return
out
def
reset_f_idx
(
self
):
self
.
f_idx
=
0
class
DefaultReductionModule
(
nn
.
Module
):
"""
Default readout:
- bag of words
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultReductionModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x_s
):
out
=
th
.
stack
(
x_s
)
out
=
th
.
sum
(
out
,
dim
=
0
)
return
out
class
DefaultReadoutModule
(
nn
.
Module
):
"""
Default readout:
- bag of words
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultReadoutModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
reduce_func
=
DefaultReductionModule
()
def
forward
(
self
,
x_s
):
return
self
.
reduce_func
(
x_s
)
class
mx_Graph
(
DiGraph
):
class
dgl_Graph
(
DiGraph
):
'''
Functions:
- m_func: per edge (u, v), default is u['state']
- u_func: per node u, default is RNN(m, u['state'])
'''
def
__init__
(
self
,
*
args
,
**
kargs
):
super
(
mx
_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
super
(
dgl
_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
self
.
m_func
=
DefaultMessageModule
()
self
.
u_func
=
DefaultUpdateModule
()
self
.
readout_func
=
DefaultReadoutModule
()
...
...
@@ -115,11 +34,6 @@ class mx_Graph(DiGraph):
return
self
.
edges
()
if
edges
==
'all'
else
edges
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
'''
batched: whether to do a single batched computation instead of iterating
message function: accepts source state tensor and edge tag tensor, and
returns a message tensor
'''
if
edges
==
'all'
:
self
.
m_func
=
message_func
else
:
...
...
@@ -127,11 +41,6 @@ class mx_Graph(DiGraph):
self
.
edges
[
e
][
'm_func'
]
=
message_func
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batched
=
False
):
'''
batched: whether to do a single batched computation instead of iterating
update function: accepts a node attribute dictionary (including state and tag),
and a list of tuples (source node, target node, edge attribute dictionary)
'''
if
nodes
==
'all'
:
self
.
u_func
=
update_func
else
:
...
...
@@ -210,54 +119,7 @@ class mx_Graph(DiGraph):
pos
=
graphviz_layout
(
self
,
prog
=
'dot'
)
nx
.
draw
(
self
,
pos
,
with_labels
=
True
)
def
set_reduction_func
(
self
):
def
_default_reduction_func
(
x_s
):
out
=
th
.
stack
(
x_s
)
out
=
th
.
sum
(
out
,
dim
=
0
)
return
out
self
.
_reduction_func
=
_default_reduction_func
def
set_gather_func
(
self
,
u
=
None
):
pass
def
print_all
(
self
):
for
n
in
self
.
nodes
:
print
(
n
,
self
.
nodes
[
n
])
print
()
if
__name__
==
'__main__'
:
from
torch.autograd
import
Variable
as
Var
th
.
random
.
manual_seed
(
0
)
print
(
"testing vanilla RNN update"
)
g_path
=
mx_Graph
(
nx
.
path_graph
(
2
))
g_path
.
set_repr
(
0
,
th
.
rand
(
2
,
128
))
g_path
.
sendto
(
0
,
1
)
g_path
.
recvfrom
(
1
,
[
0
])
g_path
.
readout
()
'''
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr.print_all()
'''
print
(
"testing GRU update"
)
g
=
mx_Graph
(
nx
.
path_graph
(
3
))
update_net
=
DefaultUpdateModule
(
h_dims
=
4
,
net_type
=
'gru'
)
g
.
register_update_func
(
update_net
)
msg_net
=
nn
.
Sequential
(
nn
.
Linear
(
4
,
4
),
nn
.
ReLU
())
g
.
register_message_func
(
msg_net
)
for
n
in
g
:
g
.
set_repr
(
n
,
th
.
rand
(
2
,
4
))
y_pre
=
g
.
readout
()
g
.
update_from
(
0
)
y_after
=
g
.
readout
()
upd_nets
=
DefaultUpdateModule
(
h_dims
=
4
,
net_type
=
'gru'
,
n_func
=
2
)
g
.
register_update_func
(
upd_nets
)
g
.
update_from
(
0
)
g
.
update_from
(
0
)
pytorch/util.py
View file @
60899a36
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
'''
Defult modules: this is Pytorch specific
- MessageModule: copy
- UpdateModule: vanilla RNN
- ReadoutModule: bag of words
- ReductionModule: bag of words
'''
class
DefaultMessageModule
(
nn
.
Module
):
"""
Default message module:
- copy
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultMessageModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x
):
return
x
class
DefaultUpdateModule
(
nn
.
Module
):
"""
Default update module:
- a vanilla GRU with ReLU, or GRU
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultUpdateModule
,
self
).
__init__
()
h_dims
=
self
.
h_dims
=
kwargs
.
get
(
'h_dims'
,
128
)
net_type
=
self
.
net_type
=
kwargs
.
get
(
'net_type'
,
'fwd'
)
n_func
=
self
.
n_func
=
kwargs
.
get
(
'n_func'
,
1
)
self
.
f_idx
=
0
self
.
reduce_func
=
DefaultReductionModule
()
if
net_type
==
'gru'
:
self
.
net
=
[
nn
.
GRUCell
(
h_dims
,
h_dims
)
for
i
in
range
(
n_func
)]
else
:
self
.
net
=
[
nn
.
Linear
(
2
*
h_dims
,
h_dims
)
for
i
in
range
(
n_func
)]
def
forward
(
self
,
x
,
msgs
):
if
not
th
.
is_tensor
(
x
):
x
=
th
.
zeros_like
(
msgs
[
0
])
m
=
self
.
reduce_func
(
msgs
)
assert
(
self
.
f_idx
<
self
.
n_func
)
if
self
.
net_type
==
'gru'
:
out
=
self
.
net
[
self
.
f_idx
](
m
,
x
)
else
:
_in
=
th
.
cat
((
m
,
x
),
1
)
out
=
F
.
relu
(
self
.
net
[
self
.
f_idx
](
_in
))
self
.
f_idx
+=
1
return
out
def
reset_f_idx
(
self
):
self
.
f_idx
=
0
class
DefaultReductionModule
(
nn
.
Module
):
"""
Default readout:
- bag of words
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultReductionModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x_s
):
out
=
th
.
stack
(
x_s
)
out
=
th
.
sum
(
out
,
dim
=
0
)
return
out
class
DefaultReadoutModule
(
nn
.
Module
):
"""
Default readout:
- bag of words
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DefaultReadoutModule
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
reduce_func
=
DefaultReductionModule
()
def
forward
(
self
,
x_s
):
return
self
.
reduce_func
(
x_s
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment