Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
704bcaf6
Unverified
Commit
704bcaf6
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
examples (#5323)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
6bc82161
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
635 additions
and
341 deletions
+635
-341
examples/pytorch/sagpool/network.py
examples/pytorch/sagpool/network.py
+2
-3
examples/pytorch/seal/logger.py
examples/pytorch/seal/logger.py
+0
-1
examples/pytorch/seal/main.py
examples/pytorch/seal/main.py
+3
-3
examples/pytorch/seal/sampler.py
examples/pytorch/seal/sampler.py
+4
-7
examples/pytorch/seal/utils.py
examples/pytorch/seal/utils.py
+2
-2
examples/pytorch/sgc/sgc.py
examples/pytorch/sgc/sgc.py
+3
-3
examples/pytorch/sgc/sgc_reddit.py
examples/pytorch/sgc/sgc_reddit.py
+2
-2
examples/pytorch/sign/dataset.py
examples/pytorch/sign/dataset.py
+1
-2
examples/pytorch/sign/sign.py
examples/pytorch/sign/sign.py
+3
-3
examples/pytorch/tagcn/train.py
examples/pytorch/tagcn/train.py
+1
-1
examples/pytorch/transformer/modules/act.py
examples/pytorch/transformer/modules/act.py
+152
-69
examples/pytorch/transformer/modules/models.py
examples/pytorch/transformer/modules/models.py
+163
-65
examples/pytorch/transformer/modules/viz.py
examples/pytorch/transformer/modules/viz.py
+235
-130
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+3
-3
examples/pytorch/tree_lstm/tree_lstm.py
examples/pytorch/tree_lstm/tree_lstm.py
+46
-33
examples/pytorch/vgae/model.py
examples/pytorch/vgae/model.py
+1
-1
examples/pytorch/vgae/train.py
examples/pytorch/vgae/train.py
+3
-3
examples/pytorch/vrgcn/train_cv.py
examples/pytorch/vrgcn/train_cv.py
+5
-5
examples/pytorch/vrgcn/train_cv_multi_gpu.py
examples/pytorch/vrgcn/train_cv_multi_gpu.py
+5
-5
examples/sparse/c_and_s.py
examples/sparse/c_and_s.py
+1
-0
No files found.
examples/pytorch/sagpool/network.py
View file @
704bcaf6
import
dgl
import
torch
import
torch
import
torch.nn
import
torch.nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
layer
import
ConvPoolBlock
,
SAGPool
import
dgl
from
dgl.nn
import
AvgPooling
,
GraphConv
,
MaxPooling
from
dgl.nn
import
AvgPooling
,
GraphConv
,
MaxPooling
from
layer
import
ConvPoolBlock
,
SAGPool
class
SAGNetworkHierarchical
(
torch
.
nn
.
Module
):
class
SAGNetworkHierarchical
(
torch
.
nn
.
Module
):
...
...
examples/pytorch/seal/logger.py
View file @
704bcaf6
...
@@ -20,7 +20,6 @@ def _transform_log_level(str_level):
...
@@ -20,7 +20,6 @@ def _transform_log_level(str_level):
class
LightLogging
(
object
):
class
LightLogging
(
object
):
def
__init__
(
self
,
log_path
=
None
,
log_name
=
"lightlog"
,
log_level
=
"debug"
):
def
__init__
(
self
,
log_path
=
None
,
log_name
=
"lightlog"
,
log_level
=
"debug"
):
log_level
=
_transform_log_level
(
log_level
)
log_level
=
_transform_log_level
(
log_level
)
if
log_path
:
if
log_path
:
...
...
examples/pytorch/seal/main.py
View file @
704bcaf6
...
@@ -3,6 +3,9 @@ import time
...
@@ -3,6 +3,9 @@ import time
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.multiprocessing
import
torch.multiprocessing
from
dgl
import
EID
,
NID
from
dgl.dataloading
import
GraphDataLoader
from
logger
import
LightLogging
from
logger
import
LightLogging
from
model
import
DGCNN
,
GCN
from
model
import
DGCNN
,
GCN
from
sampler
import
SEALData
from
sampler
import
SEALData
...
@@ -10,9 +13,6 @@ from torch.nn import BCEWithLogitsLoss
...
@@ -10,9 +13,6 @@ from torch.nn import BCEWithLogitsLoss
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
utils
import
evaluate_hits
,
load_ogb_dataset
,
parse_arguments
from
utils
import
evaluate_hits
,
load_ogb_dataset
,
parse_arguments
from
dgl
import
EID
,
NID
from
dgl.dataloading
import
GraphDataLoader
torch
.
multiprocessing
.
set_sharing_strategy
(
"file_system"
)
torch
.
multiprocessing
.
set_sharing_strategy
(
"file_system"
)
"""
"""
...
...
examples/pytorch/seal/sampler.py
View file @
704bcaf6
import
os.path
as
osp
import
os.path
as
osp
from
copy
import
deepcopy
from
copy
import
deepcopy
import
dgl
import
torch
import
torch
from
dgl
import
add_self_loop
,
DGLGraph
,
NID
from
dgl.dataloading.negative_sampler
import
Uniform
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
utils
import
drnl_node_labeling
from
utils
import
drnl_node_labeling
import
dgl
from
dgl
import
NID
,
DGLGraph
,
add_self_loop
from
dgl.dataloading.negative_sampler
import
Uniform
class
GraphDataSet
(
Dataset
):
class
GraphDataSet
(
Dataset
):
"""
"""
...
@@ -48,7 +48,6 @@ class PosNegEdgesGenerator(object):
...
@@ -48,7 +48,6 @@ class PosNegEdgesGenerator(object):
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
def
__call__
(
self
,
split_type
):
def
__call__
(
self
,
split_type
):
if
split_type
==
"train"
:
if
split_type
==
"train"
:
subsample_ratio
=
self
.
subsample_ratio
subsample_ratio
=
self
.
subsample_ratio
else
:
else
:
...
@@ -177,7 +176,6 @@ class SEALSampler(object):
...
@@ -177,7 +176,6 @@ class SEALSampler(object):
return
subgraph
return
subgraph
def
_collate
(
self
,
batch
):
def
_collate
(
self
,
batch
):
batch_graphs
,
batch_labels
=
map
(
list
,
zip
(
*
batch
))
batch_graphs
,
batch_labels
=
map
(
list
,
zip
(
*
batch
))
batch_graphs
=
dgl
.
batch
(
batch_graphs
)
batch_graphs
=
dgl
.
batch
(
batch_graphs
)
...
@@ -272,7 +270,6 @@ class SEALData(object):
...
@@ -272,7 +270,6 @@ class SEALData(object):
)
)
def
__call__
(
self
,
split_type
):
def
__call__
(
self
,
split_type
):
if
split_type
==
"train"
:
if
split_type
==
"train"
:
subsample_ratio
=
self
.
subsample_ratio
subsample_ratio
=
self
.
subsample_ratio
else
:
else
:
...
...
examples/pytorch/seal/utils.py
View file @
704bcaf6
import
argparse
import
argparse
import
dgl
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
from
ogb.linkproppred
import
DglLinkPropPredDataset
,
Evaluator
from
ogb.linkproppred
import
DglLinkPropPredDataset
,
Evaluator
from
scipy.sparse.csgraph
import
shortest_path
from
scipy.sparse.csgraph
import
shortest_path
import
dgl
def
parse_arguments
():
def
parse_arguments
():
"""
"""
...
...
examples/pytorch/sgc/sgc.py
View file @
704bcaf6
...
@@ -9,13 +9,13 @@ import argparse
...
@@ -9,13 +9,13 @@ import argparse
import
math
import
math
import
time
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.data
import
(
from
dgl.data
import
(
CiteseerGraphDataset
,
CiteseerGraphDataset
,
CoraGraphDataset
,
CoraGraphDataset
,
...
...
examples/pytorch/sgc/sgc_reddit.py
View file @
704bcaf6
...
@@ -9,12 +9,12 @@ import argparse
...
@@ -9,12 +9,12 @@ import argparse
import
math
import
math
import
time
import
time
import
dgl.function
as
fn
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl.function
as
fn
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
from
dgl.data
import
load_data
,
register_data_args
from
dgl.data
import
load_data
,
register_data_args
from
dgl.nn.pytorch.conv
import
SGConv
from
dgl.nn.pytorch.conv
import
SGConv
...
...
examples/pytorch/sign/dataset.py
View file @
704bcaf6
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
dgl
def
load_dataset
(
name
):
def
load_dataset
(
name
):
dataset
=
name
.
lower
()
dataset
=
name
.
lower
()
...
...
examples/pytorch/sign/sign.py
View file @
704bcaf6
...
@@ -2,14 +2,14 @@ import argparse
...
@@ -2,14 +2,14 @@ import argparse
import
os
import
os
import
time
import
time
import
dgl
import
dgl.function
as
fn
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
dataset
import
load_dataset
from
dataset
import
load_dataset
import
dgl
import
dgl.function
as
fn
class
FeedForwardNet
(
nn
.
Module
):
class
FeedForwardNet
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
hidden
,
out_feats
,
n_layers
,
dropout
):
def
__init__
(
self
,
in_feats
,
hidden
,
out_feats
,
n_layers
,
dropout
):
...
...
examples/pytorch/tagcn/train.py
View file @
704bcaf6
...
@@ -6,10 +6,10 @@ import numpy as np
...
@@ -6,10 +6,10 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
tagcn
import
TAGCN
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
from
dgl.data
import
load_data
,
register_data_args
from
dgl.data
import
load_data
,
register_data_args
from
tagcn
import
TAGCN
def
evaluate
(
model
,
features
,
labels
,
mask
):
def
evaluate
(
model
,
features
,
labels
,
mask
):
...
...
examples/pytorch/transformer/modules/act.py
View file @
704bcaf6
...
@@ -2,32 +2,37 @@ from .attention import *
...
@@ -2,32 +2,37 @@ from .attention import *
from
.layers
import
*
from
.layers
import
*
from
.functions
import
*
from
.functions
import
*
from
.embedding
import
*
from
.embedding
import
*
import
torch
as
th
import
dgl.function
as
fn
import
dgl.function
as
fn
import
torch
as
th
import
torch.nn.init
as
INIT
import
torch.nn.init
as
INIT
class
UEncoder
(
nn
.
Module
):
class
UEncoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
):
def
__init__
(
self
,
layer
):
super
(
UEncoder
,
self
).
__init__
()
super
(
UEncoder
,
self
).
__init__
()
self
.
layer
=
layer
self
.
layer
=
layer
self
.
norm
=
LayerNorm
(
layer
.
size
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
fields
=
'
qkv
'
):
def
pre_func
(
self
,
fields
=
"
qkv
"
):
layer
=
self
.
layer
layer
=
self
.
layer
def
func
(
nodes
):
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
func
return
func
def
post_func
(
self
):
def
post_func
(
self
):
layer
=
self
.
layer
layer
=
self
.
layer
def
func
(
nodes
):
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
}
return
{
"x"
:
x
}
return
func
return
func
...
@@ -37,31 +42,36 @@ class UDecoder(nn.Module):
...
@@ -37,31 +42,36 @@ class UDecoder(nn.Module):
self
.
layer
=
layer
self
.
layer
=
layer
self
.
norm
=
LayerNorm
(
layer
.
size
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
fields
=
'
qkv
'
,
l
=
0
):
def
pre_func
(
self
,
fields
=
"
qkv
"
,
l
=
0
):
layer
=
self
.
layer
layer
=
self
.
layer
def
func
(
nodes
):
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
if
fields
==
'
kv
'
:
if
fields
==
"
kv
"
:
norm_x
=
x
norm_x
=
x
else
:
else
:
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
func
return
func
def
post_func
(
self
,
l
=
0
):
def
post_func
(
self
,
l
=
0
):
layer
=
self
.
layer
layer
=
self
.
layer
def
func
(
nodes
):
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
if
l
==
1
:
if
l
==
1
:
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
}
return
{
"x"
:
x
}
return
func
return
func
class
HaltingUnit
(
nn
.
Module
):
class
HaltingUnit
(
nn
.
Module
):
halting_bias_init
=
1.0
halting_bias_init
=
1.0
def
__init__
(
self
,
dim_model
):
def
__init__
(
self
,
dim_model
):
super
(
HaltingUnit
,
self
).
__init__
()
super
(
HaltingUnit
,
self
).
__init__
()
self
.
linear
=
nn
.
Linear
(
dim_model
,
1
)
self
.
linear
=
nn
.
Linear
(
dim_model
,
1
)
...
@@ -71,14 +81,27 @@ class HaltingUnit(nn.Module):
...
@@ -71,14 +81,27 @@ class HaltingUnit(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
th
.
sigmoid
(
self
.
linear
(
self
.
norm
(
x
)))
return
th
.
sigmoid
(
self
.
linear
(
self
.
norm
(
x
)))
class
UTransformer
(
nn
.
Module
):
class
UTransformer
(
nn
.
Module
):
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
MAX_DEPTH
=
8
MAX_DEPTH
=
8
thres
=
0.99
thres
=
0.99
act_loss_weight
=
0.01
act_loss_weight
=
0.01
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
d_k
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
d_k
,
):
super
(
UTransformer
,
self
).
__init__
()
super
(
UTransformer
,
self
).
__init__
()
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
self
.
pos_enc
,
self
.
time_enc
=
pos_enc
,
time_enc
self
.
pos_enc
,
self
.
time_enc
=
pos_enc
,
time_enc
self
.
halt_enc
=
HaltingUnit
(
h
*
d_k
)
self
.
halt_enc
=
HaltingUnit
(
h
*
d_k
)
...
@@ -91,34 +114,45 @@ class UTransformer(nn.Module):
...
@@ -91,34 +114,45 @@ class UTransformer(nn.Module):
self
.
stat
=
[
0
]
*
(
self
.
MAX_DEPTH
+
1
)
self
.
stat
=
[
0
]
*
(
self
.
MAX_DEPTH
+
1
)
def
step_forward
(
self
,
nodes
):
def
step_forward
(
self
,
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
step
=
nodes
.
data
[
'step'
]
step
=
nodes
.
data
[
"step"
]
pos
=
nodes
.
data
[
'pos'
]
pos
=
nodes
.
data
[
"pos"
]
return
{
'x'
:
self
.
pos_enc
.
dropout
(
x
+
self
.
pos_enc
(
pos
.
view
(
-
1
))
+
self
.
time_enc
(
step
.
view
(
-
1
))),
return
{
'step'
:
step
+
1
}
"x"
:
self
.
pos_enc
.
dropout
(
x
+
self
.
pos_enc
(
pos
.
view
(
-
1
))
+
self
.
time_enc
(
step
.
view
(
-
1
))
),
"step"
:
step
+
1
,
}
def
halt_and_accum
(
self
,
name
,
end
=
False
):
def
halt_and_accum
(
self
,
name
,
end
=
False
):
"field: 'enc' or 'dec'"
"field: 'enc' or 'dec'"
halt
=
self
.
halt_enc
if
name
==
'
enc
'
else
self
.
halt_dec
halt
=
self
.
halt_enc
if
name
==
"
enc
"
else
self
.
halt_dec
thres
=
self
.
thres
thres
=
self
.
thres
def
func
(
nodes
):
def
func
(
nodes
):
p
=
halt
(
nodes
.
data
[
'x'
])
p
=
halt
(
nodes
.
data
[
"x"
])
sum_p
=
nodes
.
data
[
'
sum_p
'
]
+
p
sum_p
=
nodes
.
data
[
"
sum_p
"
]
+
p
active
=
(
sum_p
<
thres
)
&
(
1
-
end
)
active
=
(
sum_p
<
thres
)
&
(
1
-
end
)
_continue
=
active
.
float
()
_continue
=
active
.
float
()
r
=
nodes
.
data
[
'r'
]
*
(
1
-
_continue
)
+
(
1
-
sum_p
)
*
_continue
r
=
nodes
.
data
[
"r"
]
*
(
1
-
_continue
)
+
(
1
-
sum_p
)
*
_continue
s
=
nodes
.
data
[
's'
]
+
((
1
-
_continue
)
*
r
+
_continue
*
p
)
*
nodes
.
data
[
'x'
]
s
=
(
return
{
'p'
:
p
,
'sum_p'
:
sum_p
,
'r'
:
r
,
's'
:
s
,
'active'
:
active
}
nodes
.
data
[
"s"
]
+
((
1
-
_continue
)
*
r
+
_continue
*
p
)
*
nodes
.
data
[
"x"
]
)
return
{
"p"
:
p
,
"sum_p"
:
sum_p
,
"r"
:
r
,
"s"
:
s
,
"active"
:
active
}
return
func
return
func
def
propagate_attention
(
self
,
g
,
eids
):
def
propagate_attention
(
self
,
g
,
eids
):
# Compute attention score
# Compute attention score
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'
score
'
),
eids
)
g
.
apply_edges
(
src_dot_dst
(
"k"
,
"q"
,
"
score
"
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'
score
'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
scaled_exp
(
"
score
"
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
g
.
send_and_recv
(
[
fn
.
u_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
(
'score'
,
'score'
)],
eids
,
[
fn
.
sum
(
'v'
,
'wv'
),
fn
.
sum
(
'score'
,
'z'
)])
[
fn
.
u_mul_e
(
"v"
,
"score"
,
"v"
),
fn
.
copy_e
(
"score"
,
"score"
)],
[
fn
.
sum
(
"v"
,
"wv"
),
fn
.
sum
(
"score"
,
"z"
)],
)
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
"Update the node states and edge states of the graph."
...
@@ -136,79 +170,128 @@ class UTransformer(nn.Module):
...
@@ -136,79 +170,128 @@ class UTransformer(nn.Module):
nids
,
eids
=
graph
.
nids
,
graph
.
eids
nids
,
eids
=
graph
.
nids
,
graph
.
eids
# embed & pos
# embed & pos
g
.
nodes
[
nids
[
'
enc
'
]].
data
[
'x'
]
=
self
.
src_embed
(
graph
.
src
[
0
])
g
.
nodes
[
nids
[
"
enc
"
]].
data
[
"x"
]
=
self
.
src_embed
(
graph
.
src
[
0
])
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'x'
]
=
self
.
tgt_embed
(
graph
.
tgt
[
0
])
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"x"
]
=
self
.
tgt_embed
(
graph
.
tgt
[
0
])
g
.
nodes
[
nids
[
'
enc
'
]].
data
[
'
pos
'
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
"
enc
"
]].
data
[
"
pos
"
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'
pos
'
]
=
graph
.
tgt
[
1
]
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"
pos
"
]
=
graph
.
tgt
[
1
]
# init step
# init step
device
=
next
(
self
.
parameters
()).
device
device
=
next
(
self
.
parameters
()).
device
g
.
ndata
[
's'
]
=
th
.
zeros
(
N
,
self
.
h
*
self
.
d_k
,
dtype
=
th
.
float
,
device
=
device
)
# accumulated state
g
.
ndata
[
"s"
]
=
th
.
zeros
(
g
.
ndata
[
'p'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# halting prob
N
,
self
.
h
*
self
.
d_k
,
dtype
=
th
.
float
,
device
=
device
g
.
ndata
[
'r'
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# remainder
)
# accumulated state
g
.
ndata
[
'sum_p'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# sum of pondering values
g
.
ndata
[
"p"
]
=
th
.
zeros
(
g
.
ndata
[
'step'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
long
,
device
=
device
)
# step
N
,
1
,
dtype
=
th
.
float
,
device
=
device
g
.
ndata
[
'active'
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
uint8
,
device
=
device
)
# active
)
# halting prob
g
.
ndata
[
"r"
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# remainder
g
.
ndata
[
"sum_p"
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# sum of pondering values
g
.
ndata
[
"step"
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
long
,
device
=
device
)
# step
g
.
ndata
[
"active"
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
uint8
,
device
=
device
)
# active
for
step
in
range
(
self
.
MAX_DEPTH
):
for
step
in
range
(
self
.
MAX_DEPTH
):
pre_func
=
self
.
encoder
.
pre_func
(
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
()
post_func
=
self
.
encoder
.
post_func
()
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'active'
].
view
(
-
1
),
nids
[
'enc'
])
nodes
=
g
.
filter_nodes
(
if
len
(
nodes
)
==
0
:
break
lambda
v
:
v
.
data
[
"active"
].
view
(
-
1
),
nids
[
"enc"
]
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'ee'
])
)
if
len
(
nodes
)
==
0
:
break
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"ee"
]
)
end
=
step
==
self
.
MAX_DEPTH
-
1
end
=
step
==
self
.
MAX_DEPTH
-
1
self
.
update_graph
(
g
,
edges
,
self
.
update_graph
(
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
g
,
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
'enc'
,
end
),
nodes
)])
edges
,
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
"enc"
,
end
),
nodes
)],
)
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
encoder
.
norm
(
g
.
nodes
[
nids
[
'enc'
]].
data
[
's'
])
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
encoder
.
norm
(
g
.
nodes
[
nids
[
"enc"
]].
data
[
"s"
]
)
for
step
in
range
(
self
.
MAX_DEPTH
):
for
step
in
range
(
self
.
MAX_DEPTH
):
pre_func
=
self
.
decoder
.
pre_func
(
'
qkv
'
)
pre_func
=
self
.
decoder
.
pre_func
(
"
qkv
"
)
post_func
=
self
.
decoder
.
post_func
()
post_func
=
self
.
decoder
.
post_func
()
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'active'
].
view
(
-
1
),
nids
[
'dec'
])
nodes
=
g
.
filter_nodes
(
if
len
(
nodes
)
==
0
:
break
lambda
v
:
v
.
data
[
"active"
].
view
(
-
1
),
nids
[
"dec"
]
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'dd'
])
)
self
.
update_graph
(
g
,
edges
,
if
len
(
nodes
)
==
0
:
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
break
[(
post_func
,
nodes
)])
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"dd"
]
pre_q
=
self
.
decoder
.
pre_func
(
'q'
,
1
)
)
pre_kv
=
self
.
decoder
.
pre_func
(
'kv'
,
1
)
self
.
update_graph
(
g
,
edges
,
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)],
)
pre_q
=
self
.
decoder
.
pre_func
(
"q"
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
"kv"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
1
)
post_func
=
self
.
decoder
.
post_func
(
1
)
nodes_e
=
nids
[
'enc'
]
nodes_e
=
nids
[
"enc"
]
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'ed'
])
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"ed"
]
)
end
=
step
==
self
.
MAX_DEPTH
-
1
end
=
step
==
self
.
MAX_DEPTH
-
1
self
.
update_graph
(
g
,
edges
,
self
.
update_graph
(
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
g
,
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
'dec'
,
end
),
nodes
)])
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
"dec"
,
end
),
nodes
)],
)
g
.
nodes
[
nids
[
'dec'
]].
data
[
'x'
]
=
self
.
decoder
.
norm
(
g
.
nodes
[
nids
[
'dec'
]].
data
[
's'
])
g
.
nodes
[
nids
[
"dec"
]].
data
[
"x"
]
=
self
.
decoder
.
norm
(
act_loss
=
th
.
mean
(
g
.
ndata
[
'r'
])
# ACT loss
g
.
nodes
[
nids
[
"dec"
]].
data
[
"s"
]
)
act_loss
=
th
.
mean
(
g
.
ndata
[
"r"
])
# ACT loss
self
.
stat
[
0
]
+=
N
self
.
stat
[
0
]
+=
N
for
step
in
range
(
1
,
self
.
MAX_DEPTH
+
1
):
for
step
in
range
(
1
,
self
.
MAX_DEPTH
+
1
):
self
.
stat
[
step
]
+=
th
.
sum
(
g
.
ndata
[
'
step
'
]
>=
step
).
item
()
self
.
stat
[
step
]
+=
th
.
sum
(
g
.
ndata
[
"
step
"
]
>=
step
).
item
()
return
self
.
generator
(
g
.
ndata
[
'x'
][
nids
[
'dec'
]]),
act_loss
*
self
.
act_loss_weight
return
(
self
.
generator
(
g
.
ndata
[
"x"
][
nids
[
"dec"
]]),
act_loss
*
self
.
act_loss_weight
,
)
def
infer
(
self
,
*
args
,
**
kwargs
):
def
infer
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
):
def
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
):
c
=
copy
.
deepcopy
c
=
copy
.
deepcopy
attn
=
MultiHeadAttention
(
h
,
dim_model
)
attn
=
MultiHeadAttention
(
h
,
dim_model
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
time_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
time_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
encoder
=
UEncoder
(
EncoderLayer
((
dim_model
),
c
(
attn
),
c
(
ff
),
dropout
))
encoder
=
UEncoder
(
EncoderLayer
((
dim_model
),
c
(
attn
),
c
(
ff
),
dropout
))
decoder
=
UDecoder
(
DecoderLayer
((
dim_model
),
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
))
decoder
=
UDecoder
(
DecoderLayer
((
dim_model
),
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
)
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
model
=
UTransformer
(
model
=
UTransformer
(
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
dim_model
//
h
)
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
dim_model
//
h
,
)
# xavier init
# xavier init
for
p
in
model
.
parameters
():
for
p
in
model
.
parameters
():
if
p
.
dim
()
>
1
:
if
p
.
dim
()
>
1
:
...
...
examples/pytorch/transformer/modules/models.py
View file @
704bcaf6
...
@@ -6,10 +6,12 @@ from .layers import *
...
@@ -6,10 +6,12 @@ from .layers import *
from
.functions
import
*
from
.functions
import
*
from
.embedding
import
*
from
.embedding
import
*
import
threading
import
threading
import
torch
as
th
import
dgl.function
as
fn
import
dgl.function
as
fn
import
torch
as
th
import
torch.nn.init
as
INIT
import
torch.nn.init
as
INIT
class
Encoder
(
nn
.
Module
):
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
,
N
):
def
__init__
(
self
,
layer
,
N
):
super
(
Encoder
,
self
).
__init__
()
super
(
Encoder
,
self
).
__init__
()
...
@@ -17,24 +19,29 @@ class Encoder(nn.Module):
...
@@ -17,24 +19,29 @@ class Encoder(nn.Module):
self
.
layers
=
clones
(
layer
,
N
)
self
.
layers
=
clones
(
layer
,
N
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
i
,
fields
=
'
qkv
'
):
def
pre_func
(
self
,
i
,
fields
=
"
qkv
"
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
func
return
func
def
post_func
(
self
,
i
):
def
post_func
(
self
,
i
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
{
"x"
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
func
return
func
class
Decoder
(
nn
.
Module
):
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
,
N
):
def
__init__
(
self
,
layer
,
N
):
super
(
Decoder
,
self
).
__init__
()
super
(
Decoder
,
self
).
__init__
()
...
@@ -42,32 +49,39 @@ class Decoder(nn.Module):
...
@@ -42,32 +49,39 @@ class Decoder(nn.Module):
self
.
layers
=
clones
(
layer
,
N
)
self
.
layers
=
clones
(
layer
,
N
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
i
,
fields
=
'
qkv
'
,
l
=
0
):
def
pre_func
(
self
,
i
,
fields
=
"
qkv
"
,
l
=
0
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
if
fields
.
startswith
(
'q'
)
else
x
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
if
fields
.
startswith
(
"q"
)
else
x
if
fields
!=
'
qkv
'
:
if
fields
!=
"
qkv
"
:
return
layer
.
src_attn
.
get
(
norm_x
,
fields
)
return
layer
.
src_attn
.
get
(
norm_x
,
fields
)
else
:
else
:
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
func
return
func
def
post_func
(
self
,
i
,
l
=
0
):
def
post_func
(
self
,
i
,
l
=
0
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
if
l
==
1
:
if
l
==
1
:
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
{
"x"
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
func
return
func
class
Transformer
(
nn
.
Module
):
class
Transformer
(
nn
.
Module
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
d_k
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
d_k
):
super
(
Transformer
,
self
).
__init__
()
super
(
Transformer
,
self
).
__init__
()
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
self
.
pos_enc
=
pos_enc
self
.
pos_enc
=
pos_enc
self
.
generator
=
generator
self
.
generator
=
generator
...
@@ -76,11 +90,11 @@ class Transformer(nn.Module):
...
@@ -76,11 +90,11 @@ class Transformer(nn.Module):
def
propagate_attention
(
self
,
g
,
eids
):
def
propagate_attention
(
self
,
g
,
eids
):
# Compute attention score
# Compute attention score
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'
score
'
),
eids
)
g
.
apply_edges
(
src_dot_dst
(
"k"
,
"q"
,
"
score
"
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'
score
'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
scaled_exp
(
"
score
"
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
fn
.
u_mul_e
(
'v'
,
'
score
'
,
'v'
),
fn
.
sum
(
'v'
,
'
wv
'
))
g
.
send_and_recv
(
eids
,
fn
.
u_mul_e
(
"v"
,
"
score
"
,
"v"
),
fn
.
sum
(
"v"
,
"
wv
"
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
'
score
'
,
'
score
'
),
fn
.
sum
(
'
score
'
,
'z'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
"
score
"
,
"
score
"
),
fn
.
sum
(
"
score
"
,
"z"
))
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
"Update the node states and edge states of the graph."
...
@@ -98,27 +112,44 @@ class Transformer(nn.Module):
...
@@ -98,27 +112,44 @@ class Transformer(nn.Module):
nids
,
eids
=
graph
.
nids
,
graph
.
eids
nids
,
eids
=
graph
.
nids
,
graph
.
eids
# embed
# embed
src_embed
,
src_pos
=
self
.
src_embed
(
graph
.
src
[
0
]),
self
.
pos_enc
(
graph
.
src
[
1
])
src_embed
,
src_pos
=
self
.
src_embed
(
graph
.
src
[
0
]),
self
.
pos_enc
(
tgt_embed
,
tgt_pos
=
self
.
tgt_embed
(
graph
.
tgt
[
0
]),
self
.
pos_enc
(
graph
.
tgt
[
1
])
graph
.
src
[
1
]
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
)
g
.
nodes
[
nids
[
'dec'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
tgt_embed
,
tgt_pos
=
self
.
tgt_embed
(
graph
.
tgt
[
0
]),
self
.
pos_enc
(
graph
.
tgt
[
1
]
)
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
g
.
nodes
[
nids
[
"dec"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
for
i
in
range
(
self
.
encoder
.
N
):
for
i
in
range
(
self
.
encoder
.
N
):
pre_func
=
self
.
encoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
(
i
)
post_func
=
self
.
encoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'enc'
],
eids
[
'ee'
]
nodes
,
edges
=
nids
[
"enc"
],
eids
[
"ee"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
for
i
in
range
(
self
.
decoder
.
N
):
for
i
in
range
(
self
.
decoder
.
N
):
pre_func
=
self
.
decoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
decoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
decoder
.
post_func
(
i
)
post_func
=
self
.
decoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'dec'
],
eids
[
'dd'
]
nodes
,
edges
=
nids
[
"dec"
],
eids
[
"dd"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
self
.
update_graph
(
pre_q
=
self
.
decoder
.
pre_func
(
i
,
'q'
,
1
)
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
'kv'
,
1
)
)
pre_q
=
self
.
decoder
.
pre_func
(
i
,
"q"
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
"kv"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
nodes_e
,
edges
=
nids
[
'enc'
],
eids
[
'ed'
]
nodes_e
,
edges
=
nids
[
"enc"
],
eids
[
"ed"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
)])
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
)],
)
# visualize attention
# visualize attention
"""
"""
...
@@ -126,9 +157,10 @@ class Transformer(nn.Module):
...
@@ -126,9 +157,10 @@ class Transformer(nn.Module):
self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
"""
"""
return
self
.
generator
(
g
.
ndata
[
'x'
][
nids
[
'dec'
]])
return
self
.
generator
(
g
.
ndata
[
"x"
][
nids
[
"dec"
]])
def
infer
(
self
,
graph
,
max_len
,
eos_id
,
k
,
alpha
=
1.0
):
def
infer
(
self
,
graph
,
max_len
,
eos_id
,
k
,
alpha
=
1.0
):
'''
"""
This function implements Beam Search in DGL, which is required in inference phase.
This function implements Beam Search in DGL, which is required in inference phase.
Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
args:
args:
...
@@ -138,7 +170,7 @@ class Transformer(nn.Module):
...
@@ -138,7 +170,7 @@ class Transformer(nn.Module):
k: beam size
k: beam size
return:
return:
ret: a list of index array correspond to the input sequence specified by `graph``.
ret: a list of index array correspond to the input sequence specified by `graph``.
'''
"""
g
=
graph
.
g
g
=
graph
.
g
N
,
E
=
graph
.
n_nodes
,
graph
.
n_edges
N
,
E
=
graph
.
n_nodes
,
graph
.
n_edges
nids
,
eids
=
graph
.
nids
,
graph
.
eids
nids
,
eids
=
graph
.
nids
,
graph
.
eids
...
@@ -146,21 +178,25 @@ class Transformer(nn.Module):
...
@@ -146,21 +178,25 @@ class Transformer(nn.Module):
# embed & pos
# embed & pos
src_embed
=
self
.
src_embed
(
graph
.
src
[
0
])
src_embed
=
self
.
src_embed
(
graph
.
src
[
0
])
src_pos
=
self
.
pos_enc
(
graph
.
src
[
1
])
src_pos
=
self
.
pos_enc
(
graph
.
src
[
1
])
g
.
nodes
[
nids
[
'enc'
]].
data
[
'pos'
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
"enc"
]].
data
[
"pos"
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
tgt_pos
=
self
.
pos_enc
(
graph
.
tgt
[
1
])
tgt_pos
=
self
.
pos_enc
(
graph
.
tgt
[
1
])
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'
pos
'
]
=
graph
.
tgt
[
1
]
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"
pos
"
]
=
graph
.
tgt
[
1
]
# init mask
# init mask
device
=
next
(
self
.
parameters
()).
device
device
=
next
(
self
.
parameters
()).
device
g
.
ndata
[
'
mask
'
]
=
th
.
zeros
(
N
,
dtype
=
th
.
uint8
,
device
=
device
)
g
.
ndata
[
"
mask
"
]
=
th
.
zeros
(
N
,
dtype
=
th
.
uint8
,
device
=
device
)
# encode
# encode
for
i
in
range
(
self
.
encoder
.
N
):
for
i
in
range
(
self
.
encoder
.
N
):
pre_func
=
self
.
encoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
(
i
)
post_func
=
self
.
encoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'enc'
],
eids
[
'ee'
]
nodes
,
edges
=
nids
[
"enc"
],
eids
[
"ee"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
# decode
# decode
log_prob
=
None
log_prob
=
None
...
@@ -168,36 +204,76 @@ class Transformer(nn.Module):
...
@@ -168,36 +204,76 @@ class Transformer(nn.Module):
for
step
in
range
(
1
,
max_len
):
for
step
in
range
(
1
,
max_len
):
y
=
y
.
view
(
-
1
)
y
=
y
.
view
(
-
1
)
tgt_embed
=
self
.
tgt_embed
(
y
)
tgt_embed
=
self
.
tgt_embed
(
y
)
g
.
ndata
[
'x'
][
nids
[
'dec'
]]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
g
.
ndata
[
"x"
][
nids
[
"dec"
]]
=
self
.
pos_enc
.
dropout
(
edges_ed
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
'pos'
]
<
step
)
&
~
e
.
dst
[
'mask'
].
bool
(),
eids
[
'ed'
])
tgt_embed
+
tgt_pos
edges_dd
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
'pos'
]
<
step
)
&
~
e
.
dst
[
'mask'
].
bool
(),
eids
[
'dd'
])
)
nodes_d
=
g
.
filter_nodes
(
lambda
v
:
(
v
.
data
[
'pos'
]
<
step
)
&
~
v
.
data
[
'mask'
].
bool
(),
nids
[
'dec'
])
edges_ed
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
"pos"
]
<
step
)
&
~
e
.
dst
[
"mask"
].
bool
(),
eids
[
"ed"
],
)
edges_dd
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
"pos"
]
<
step
)
&
~
e
.
dst
[
"mask"
].
bool
(),
eids
[
"dd"
],
)
nodes_d
=
g
.
filter_nodes
(
lambda
v
:
(
v
.
data
[
"pos"
]
<
step
)
&
~
v
.
data
[
"mask"
].
bool
(),
nids
[
"dec"
],
)
for
i
in
range
(
self
.
decoder
.
N
):
for
i
in
range
(
self
.
decoder
.
N
):
pre_func
,
post_func
=
self
.
decoder
.
pre_func
(
i
,
'qkv'
),
self
.
decoder
.
post_func
(
i
)
pre_func
,
post_func
=
self
.
decoder
.
pre_func
(
i
,
"qkv"
),
self
.
decoder
.
post_func
(
i
)
nodes
,
edges
=
nodes_d
,
edges_dd
nodes
,
edges
=
nodes_d
,
edges_dd
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
self
.
update_graph
(
pre_q
,
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
'q'
,
1
),
self
.
decoder
.
pre_func
(
i
,
'kv'
,
1
)
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
pre_q
,
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
"q"
,
1
),
self
.
decoder
.
pre_func
(
i
,
"kv"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
nodes_e
,
nodes_d
,
edges
=
nids
[
'enc'
],
nodes_d
,
edges_ed
nodes_e
,
nodes_d
,
edges
=
nids
[
"enc"
],
nodes_d
,
edges_ed
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes_d
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes_d
)])
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes_d
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes_d
)],
)
frontiers
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'pos'
]
==
step
-
1
,
nids
[
'dec'
])
frontiers
=
g
.
filter_nodes
(
out
=
self
.
generator
(
g
.
ndata
[
'x'
][
frontiers
])
lambda
v
:
v
.
data
[
"pos"
]
==
step
-
1
,
nids
[
"dec"
]
)
out
=
self
.
generator
(
g
.
ndata
[
"x"
][
frontiers
])
batch_size
=
frontiers
.
shape
[
0
]
//
k
batch_size
=
frontiers
.
shape
[
0
]
//
k
vocab_size
=
out
.
shape
[
-
1
]
vocab_size
=
out
.
shape
[
-
1
]
# Mask output for complete sequence
# Mask output for complete sequence
one_hot
=
th
.
zeros
(
vocab_size
).
fill_
(
-
1e9
).
to
(
device
)
one_hot
=
th
.
zeros
(
vocab_size
).
fill_
(
-
1e9
).
to
(
device
)
one_hot
[
eos_id
]
=
0
one_hot
[
eos_id
]
=
0
mask
=
g
.
ndata
[
'
mask
'
][
frontiers
].
unsqueeze
(
-
1
).
float
()
mask
=
g
.
ndata
[
"
mask
"
][
frontiers
].
unsqueeze
(
-
1
).
float
()
out
=
out
*
(
1
-
mask
)
+
one_hot
.
unsqueeze
(
0
)
*
mask
out
=
out
*
(
1
-
mask
)
+
one_hot
.
unsqueeze
(
0
)
*
mask
if
log_prob
is
None
:
if
log_prob
is
None
:
log_prob
,
pos
=
out
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
topk
(
k
,
dim
=-
1
)
log_prob
,
pos
=
out
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
topk
(
k
,
dim
=-
1
)
eos
=
th
.
zeros
(
batch_size
,
k
).
byte
()
eos
=
th
.
zeros
(
batch_size
,
k
).
byte
()
else
:
else
:
norm_old
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
))
*
np
.
power
((
4.
+
step
)
/
6
,
alpha
)
norm_old
=
eos
.
float
().
to
(
device
)
+
(
norm_new
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
))
*
np
.
power
((
5.
+
step
)
/
6
,
alpha
)
1
-
eos
.
float
().
to
(
device
)
log_prob
,
pos
=
((
out
.
view
(
batch_size
,
k
,
-
1
)
+
(
log_prob
*
norm_old
).
unsqueeze
(
-
1
))
/
norm_new
.
unsqueeze
(
-
1
)).
view
(
batch_size
,
-
1
).
topk
(
k
,
dim
=-
1
)
)
*
np
.
power
((
4.0
+
step
)
/
6
,
alpha
)
norm_new
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
)
)
*
np
.
power
((
5.0
+
step
)
/
6
,
alpha
)
log_prob
,
pos
=
(
(
(
out
.
view
(
batch_size
,
k
,
-
1
)
+
(
log_prob
*
norm_old
).
unsqueeze
(
-
1
)
)
/
norm_new
.
unsqueeze
(
-
1
)
)
.
view
(
batch_size
,
-
1
)
.
topk
(
k
,
dim
=-
1
)
)
_y
=
y
.
view
(
batch_size
*
k
,
-
1
)
_y
=
y
.
view
(
batch_size
*
k
,
-
1
)
y
=
th
.
zeros_like
(
_y
)
y
=
th
.
zeros_like
(
_y
)
...
@@ -206,14 +282,16 @@ class Transformer(nn.Module):
...
@@ -206,14 +282,16 @@ class Transformer(nn.Module):
for
j
in
range
(
k
):
for
j
in
range
(
k
):
_j
=
pos
[
i
,
j
].
item
()
//
vocab_size
_j
=
pos
[
i
,
j
].
item
()
//
vocab_size
token
=
pos
[
i
,
j
].
item
()
%
vocab_size
token
=
pos
[
i
,
j
].
item
()
%
vocab_size
y
[
i
*
k
+
j
,
:]
=
_y
[
i
*
k
+
_j
,
:]
y
[
i
*
k
+
j
,
:]
=
_y
[
i
*
k
+
_j
,
:]
y
[
i
*
k
+
j
,
step
]
=
token
y
[
i
*
k
+
j
,
step
]
=
token
eos
[
i
,
j
]
=
_eos
[
i
,
_j
]
|
(
token
==
eos_id
)
eos
[
i
,
j
]
=
_eos
[
i
,
_j
]
|
(
token
==
eos_id
)
if
eos
.
all
():
if
eos
.
all
():
break
break
else
:
else
:
g
.
ndata
[
'mask'
][
nids
[
'dec'
]]
=
eos
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
max_len
).
view
(
-
1
).
to
(
device
)
g
.
ndata
[
"mask"
][
nids
[
"dec"
]]
=
(
eos
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
max_len
).
view
(
-
1
).
to
(
device
)
)
return
y
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
tolist
()
return
y
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
tolist
()
def
_register_att_map
(
self
,
g
,
enc_ids
,
dec_ids
):
def
_register_att_map
(
self
,
g
,
enc_ids
,
dec_ids
):
...
@@ -224,22 +302,42 @@ class Transformer(nn.Module):
...
@@ -224,22 +302,42 @@ class Transformer(nn.Module):
]
]
def
make_model
(
src_vocab
,
tgt_vocab
,
N
=
6
,
def
make_model
(
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
,
universal
=
False
):
src_vocab
,
tgt_vocab
,
N
=
6
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
,
universal
=
False
,
):
if
universal
:
if
universal
:
return
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
,
dim_ff
,
h
,
dropout
)
return
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
,
dim_ff
,
h
,
dropout
)
c
=
copy
.
deepcopy
c
=
copy
.
deepcopy
attn
=
MultiHeadAttention
(
h
,
dim_model
)
attn
=
MultiHeadAttention
(
h
,
dim_model
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
encoder
=
Encoder
(
EncoderLayer
(
dim_model
,
c
(
attn
),
c
(
ff
),
dropout
),
N
)
encoder
=
Encoder
(
EncoderLayer
(
dim_model
,
c
(
attn
),
c
(
ff
),
dropout
),
N
)
decoder
=
Decoder
(
DecoderLayer
(
dim_model
,
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
),
N
)
decoder
=
Decoder
(
DecoderLayer
(
dim_model
,
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
),
N
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
model
=
Transformer
(
model
=
Transformer
(
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
dim_model
//
h
)
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
dim_model
//
h
,
)
# xavier init
# xavier init
for
p
in
model
.
parameters
():
for
p
in
model
.
parameters
():
if
p
.
dim
()
>
1
:
if
p
.
dim
()
>
1
:
...
...
examples/pytorch/transformer/modules/viz.py
View file @
704bcaf6
import
os
import
os
import
numpy
as
np
import
torch
as
th
import
networkx
as
nx
import
matplotlib
as
mpl
import
matplotlib
as
mpl
import
matplotlib.pyplot
as
plt
import
matplotlib.animation
as
animation
import
matplotlib.animation
as
animation
import
matplotlib.pyplot
as
plt
import
networkx
as
nx
import
numpy
as
np
import
torch
as
th
from
networkx.algorithms
import
bipartite
from
networkx.algorithms
import
bipartite
def
get_attention_map
(
g
,
src_nodes
,
dst_nodes
,
h
):
def
get_attention_map
(
g
,
src_nodes
,
dst_nodes
,
h
):
"""
"""
To visualize the attention score between two set of nodes.
To visualize the attention score between two set of nodes.
...
@@ -18,14 +20,15 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
...
@@ -18,14 +20,15 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
if
not
g
.
has_edge_between
(
src
,
dst
):
if
not
g
.
has_edge_between
(
src
,
dst
):
continue
continue
eid
=
g
.
edge_ids
(
src
,
dst
)
eid
=
g
.
edge_ids
(
src
,
dst
)
weight
[
i
][
j
]
=
g
.
edata
[
'
score
'
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
[
i
][
j
]
=
g
.
edata
[
"
score
"
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
=
weight
.
transpose
(
0
,
2
)
weight
=
weight
.
transpose
(
0
,
2
)
att
=
th
.
softmax
(
weight
,
-
2
)
att
=
th
.
softmax
(
weight
,
-
2
)
return
att
.
numpy
()
return
att
.
numpy
()
def
draw_heatmap
(
array
,
input_seq
,
output_seq
,
dirname
,
name
):
def
draw_heatmap
(
array
,
input_seq
,
output_seq
,
dirname
,
name
):
dirname
=
os
.
path
.
join
(
'
log
'
,
dirname
)
dirname
=
os
.
path
.
join
(
"
log
"
,
dirname
)
if
not
os
.
path
.
exists
(
dirname
):
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
os
.
makedirs
(
dirname
)
...
@@ -38,30 +41,37 @@ def draw_heatmap(array, input_seq, output_seq, dirname, name):
...
@@ -38,30 +41,37 @@ def draw_heatmap(array, input_seq, output_seq, dirname, name):
axes
[
i
,
j
].
set_xticks
(
np
.
arange
(
len
(
output_seq
)))
axes
[
i
,
j
].
set_xticks
(
np
.
arange
(
len
(
output_seq
)))
axes
[
i
,
j
].
set_yticklabels
(
input_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_yticklabels
(
input_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_xticklabels
(
output_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_xticklabels
(
output_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_title
(
'head_{}'
.
format
(
cnt
),
fontsize
=
10
)
axes
[
i
,
j
].
set_title
(
"head_{}"
.
format
(
cnt
),
fontsize
=
10
)
plt
.
setp
(
axes
[
i
,
j
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
plt
.
setp
(
rotation_mode
=
"anchor"
)
axes
[
i
,
j
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
rotation_mode
=
"anchor"
,
)
cnt
+=
1
cnt
+=
1
fig
.
suptitle
(
name
,
fontsize
=
12
)
fig
.
suptitle
(
name
,
fontsize
=
12
)
plt
.
tight_layout
()
plt
.
tight_layout
()
plt
.
savefig
(
os
.
path
.
join
(
dirname
,
'
{}.pdf
'
.
format
(
name
)))
plt
.
savefig
(
os
.
path
.
join
(
dirname
,
"
{}.pdf
"
.
format
(
name
)))
plt
.
close
()
plt
.
close
()
def draw_atts(maps, src, tgt, dirname, prefix):
    """Render all three attention heatmaps for one example.

    maps[0]: encoder self-attention
    maps[1]: encoder-decoder attention
    maps[2]: decoder self-attention
    """
    # (attention map, row labels, column labels, file suffix)
    panels = [
        (maps[0], src, src, "enc_self_attn"),
        (maps[1], src, tgt, "enc_dec_attn"),
        (maps[2], tgt, tgt, "dec_self_attn"),
    ]
    for weight, rows, cols, suffix in panels:
        draw_heatmap(weight, rows, cols, dirname, "{}_{}".format(prefix, suffix))
# Index of each attention mode inside a per-epoch maps triple:
# encoder self-attention, encoder-decoder attention, decoder self-attention.
mode2id = {"e2e": 0, "e2d": 1, "d2d": 2}

# Module-level handle to the most recent matplotlib colorbar so the
# animation callback can remove it before redrawing the next frame.
colorbar = None
def
att_animation
(
maps_array
,
mode
,
src
,
tgt
,
head_id
):
def
att_animation
(
maps_array
,
mode
,
src
,
tgt
,
head_id
):
weights
=
[
maps
[
mode2id
[
mode
]][
head_id
]
for
maps
in
maps_array
]
weights
=
[
maps
[
mode2id
[
mode
]][
head_id
]
for
maps
in
maps_array
]
fig
,
axes
=
plt
.
subplots
(
1
,
2
)
fig
,
axes
=
plt
.
subplots
(
1
,
2
)
...
@@ -71,75 +81,125 @@ def att_animation(maps_array, mode, src, tgt, head_id):
...
@@ -71,75 +81,125 @@ def att_animation(maps_array, mode, src, tgt, head_id):
if
colorbar
:
if
colorbar
:
colorbar
.
remove
()
colorbar
.
remove
()
plt
.
cla
()
plt
.
cla
()
axes
[
0
].
set_title
(
'
heatmap
'
)
axes
[
0
].
set_title
(
"
heatmap
"
)
axes
[
0
].
set_yticks
(
np
.
arange
(
len
(
src
)))
axes
[
0
].
set_yticks
(
np
.
arange
(
len
(
src
)))
axes
[
0
].
set_xticks
(
np
.
arange
(
len
(
tgt
)))
axes
[
0
].
set_xticks
(
np
.
arange
(
len
(
tgt
)))
axes
[
0
].
set_yticklabels
(
src
)
axes
[
0
].
set_yticklabels
(
src
)
axes
[
0
].
set_xticklabels
(
tgt
)
axes
[
0
].
set_xticklabels
(
tgt
)
plt
.
setp
(
axes
[
0
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
plt
.
setp
(
rotation_mode
=
"anchor"
)
axes
[
0
].
get_xticklabels
(),
rotation
=
45
,
fig
.
suptitle
(
'epoch {}'
.
format
(
i
))
ha
=
"right"
,
rotation_mode
=
"anchor"
,
)
fig
.
suptitle
(
"epoch {}"
.
format
(
i
))
weight
=
weights
[
i
].
transpose
(
-
1
,
-
2
)
weight
=
weights
[
i
].
transpose
(
-
1
,
-
2
)
heatmap
=
axes
[
0
].
pcolor
(
weight
,
vmin
=
0
,
vmax
=
1
,
cmap
=
plt
.
cm
.
Blues
)
heatmap
=
axes
[
0
].
pcolor
(
weight
,
vmin
=
0
,
vmax
=
1
,
cmap
=
plt
.
cm
.
Blues
)
colorbar
=
plt
.
colorbar
(
heatmap
,
ax
=
axes
[
0
],
fraction
=
0.046
,
pad
=
0.04
)
colorbar
=
plt
.
colorbar
(
heatmap
,
ax
=
axes
[
0
],
fraction
=
0.046
,
pad
=
0.04
)
axes
[
0
].
set_aspect
(
'
equal
'
)
axes
[
0
].
set_aspect
(
"
equal
"
)
axes
[
1
].
axis
(
"off"
)
axes
[
1
].
axis
(
"off"
)
graph_att_head
(
src
,
tgt
,
weight
,
axes
[
1
],
'graph'
)
graph_att_head
(
src
,
tgt
,
weight
,
axes
[
1
],
"graph"
)
ani
=
animation
.
FuncAnimation
(
ani
=
animation
.
FuncAnimation
(
fig
,
weight_animate
,
frames
=
len
(
weights
),
interval
=
500
,
repeat_delay
=
2000
)
fig
,
weight_animate
,
frames
=
len
(
weights
),
interval
=
500
,
repeat_delay
=
2000
,
)
return
ani
return
ani
def graph_att_head(M, N, weight, ax, title):
    "credit: Jinjing Zhou"
    # Draw attention weights between token set M (left) and N (right)
    # as a complete bipartite graph; edge width encodes weight[i, j].
    in_nodes = len(M)
    out_nodes = len(N)
    g = nx.bipartite.generators.complete_bipartite_graph(in_nodes, out_nodes)
    X, Y = bipartite.sets(g)

    # Vertical layout: spread left column over [0, 10]; center the right
    # column on the same span.
    height_in = 10
    height_out = height_in
    height_in_y = np.linspace(0, height_in, in_nodes)
    height_out_y = np.linspace(
        (height_in - height_out) / 2, height_out, out_nodes
    )

    pos = dict()
    # put nodes from X at x=1
    pos.update((n, (1, i)) for i, n in zip(height_in_y, X))
    # put nodes from Y at x=2
    pos.update((n, (3, i)) for i, n in zip(height_out_y, Y))

    ax.axis("off")
    ax.set_xlim(-1, 4)
    ax.set_title(title)

    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=50, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=50,
        ax=ax,
    )
    # One draw call per edge so each can get its own width.
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=weight[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )
    nx.draw_networkx_labels(
        g,
        pos,
        {i: label + " " for i, label in enumerate(M)},
        horizontalalignment="right",
        font_size=8,
        ax=ax,
    )
    nx.draw_networkx_labels(
        g,
        pos,
        {i + in_nodes: " " + label for i, label in enumerate(N)},
        horizontalalignment="left",
        font_size=8,
        ax=ax,
    )
import
networkx
as
nx
import
networkx
as
nx
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
from
networkx.utils
import
is_string_like
from
networkx.utils
import
is_string_like
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
"The following function was modified from the source code of networkx"
"The following function was modified from the source code of networkx"
def
draw_networkx_edges
(
G
,
pos
,
edgelist
=
None
,
width
=
1.0
,
def
draw_networkx_edges
(
edge_color
=
'k'
,
G
,
style
=
'solid'
,
pos
,
alpha
=
1.0
,
edgelist
=
None
,
arrowstyle
=
'-|>'
,
width
=
1.0
,
arrowsize
=
10
,
edge_color
=
"k"
,
edge_cmap
=
None
,
style
=
"solid"
,
edge_vmin
=
None
,
alpha
=
1.0
,
edge_vmax
=
None
,
arrowstyle
=
"-|>"
,
ax
=
None
,
arrowsize
=
10
,
arrows
=
True
,
edge_cmap
=
None
,
label
=
None
,
edge_vmin
=
None
,
node_size
=
300
,
edge_vmax
=
None
,
nodelist
=
None
,
ax
=
None
,
node_shape
=
"o"
,
arrows
=
True
,
connectionstyle
=
'arc3'
,
label
=
None
,
**
kwds
):
node_size
=
300
,
nodelist
=
None
,
node_shape
=
"o"
,
connectionstyle
=
"arc3"
,
**
kwds
):
"""Draw the edges of the graph G.
"""Draw the edges of the graph G.
This draws only the edges of the graph G.
This draws only the edges of the graph G.
...
@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos,
...
@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos,
"""
"""
try
:
try
:
import
matplotlib
import
matplotlib
import
matplotlib.pyplot
as
plt
import
matplotlib.cbook
as
cb
import
matplotlib.cbook
as
cb
from
matplotlib.colors
import
colorConverter
,
Colormap
,
Normalize
import
matplotlib.pyplot
as
plt
from
matplotlib.collections
import
LineCollection
from
matplotlib.patches
import
FancyArrowPatch
,
ConnectionStyle
import
numpy
as
np
import
numpy
as
np
from
matplotlib.collections
import
LineCollection
from
matplotlib.colors
import
colorConverter
,
Colormap
,
Normalize
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Matplotlib required for draw()"
)
raise
ImportError
(
"Matplotlib required for draw()"
)
except
RuntimeError
:
except
RuntimeError
:
...
@@ -270,39 +330,44 @@ def draw_networkx_edges(G, pos,
...
@@ -270,39 +330,44 @@ def draw_networkx_edges(G, pos,
else
:
else
:
lw
=
width
lw
=
width
if
not
is_string_like
(
edge_color
)
\
if
(
and
cb
.
iterable
(
edge_color
)
\
not
is_string_like
(
edge_color
)
and
len
(
edge_color
)
==
len
(
edge_pos
):
and
cb
.
iterable
(
edge_color
)
and
len
(
edge_color
)
==
len
(
edge_pos
)
):
if
np
.
alltrue
([
is_string_like
(
c
)
for
c
in
edge_color
]):
if
np
.
alltrue
([
is_string_like
(
c
)
for
c
in
edge_color
]):
# (should check ALL elements)
# (should check ALL elements)
# list of color letters such as ['k','r','k',...]
# list of color letters such as ['k','r','k',...]
edge_colors
=
tuple
([
colorConverter
.
to_rgba
(
c
,
alpha
)
edge_colors
=
tuple
(
for
c
in
edge_color
])
[
colorConverter
.
to_rgba
(
c
,
alpha
)
for
c
in
edge_color
]
)
elif
np
.
alltrue
([
not
is_string_like
(
c
)
for
c
in
edge_color
]):
elif
np
.
alltrue
([
not
is_string_like
(
c
)
for
c
in
edge_color
]):
# If color specs are given as (rgb) or (rgba) tuples, we're OK
# If color specs are given as (rgb) or (rgba) tuples, we're OK
if
np
.
alltrue
([
cb
.
iterable
(
c
)
and
len
(
c
)
in
(
3
,
4
)
if
np
.
alltrue
(
for
c
in
edge_color
]):
[
cb
.
iterable
(
c
)
and
len
(
c
)
in
(
3
,
4
)
for
c
in
edge_color
]
):
edge_colors
=
tuple
(
edge_color
)
edge_colors
=
tuple
(
edge_color
)
else
:
else
:
# numbers (which are going to be mapped with a colormap)
# numbers (which are going to be mapped with a colormap)
edge_colors
=
None
edge_colors
=
None
else
:
else
:
raise
ValueError
(
'
edge_color must contain color names or numbers
'
)
raise
ValueError
(
"
edge_color must contain color names or numbers
"
)
else
:
else
:
if
is_string_like
(
edge_color
)
or
len
(
edge_color
)
==
1
:
if
is_string_like
(
edge_color
)
or
len
(
edge_color
)
==
1
:
edge_colors
=
(
colorConverter
.
to_rgba
(
edge_color
,
alpha
),
)
edge_colors
=
(
colorConverter
.
to_rgba
(
edge_color
,
alpha
),)
else
:
else
:
msg
=
'
edge_color must be a color or list of one color per edge
'
msg
=
"
edge_color must be a color or list of one color per edge
"
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
if
(
not
G
.
is_directed
()
or
not
arrows
):
if
not
G
.
is_directed
()
or
not
arrows
:
edge_collection
=
LineCollection
(
edge_pos
,
edge_collection
=
LineCollection
(
colors
=
edge_colors
,
edge_pos
,
linewidths
=
lw
,
colors
=
edge_colors
,
antialiaseds
=
(
1
,),
linewidths
=
lw
,
linestyle
=
style
,
antialiaseds
=
(
1
,),
transOffset
=
ax
.
transData
,
linestyle
=
style
,
)
transOffset
=
ax
.
transData
,
)
edge_collection
.
set_zorder
(
1
)
# edges go behind nodes
edge_collection
.
set_zorder
(
1
)
# edges go behind nodes
edge_collection
.
set_label
(
label
)
edge_collection
.
set_label
(
label
)
...
@@ -318,7 +383,7 @@ def draw_networkx_edges(G, pos,
...
@@ -318,7 +383,7 @@ def draw_networkx_edges(G, pos,
if
edge_colors
is
None
:
if
edge_colors
is
None
:
if
edge_cmap
is
not
None
:
if
edge_cmap
is
not
None
:
assert
(
isinstance
(
edge_cmap
,
Colormap
)
)
assert
isinstance
(
edge_cmap
,
Colormap
)
edge_collection
.
set_array
(
np
.
asarray
(
edge_color
))
edge_collection
.
set_array
(
np
.
asarray
(
edge_color
))
edge_collection
.
set_cmap
(
edge_cmap
)
edge_collection
.
set_cmap
(
edge_cmap
)
if
edge_vmin
is
not
None
or
edge_vmax
is
not
None
:
if
edge_vmin
is
not
None
or
edge_vmax
is
not
None
:
...
@@ -346,7 +411,7 @@ def draw_networkx_edges(G, pos,
...
@@ -346,7 +411,7 @@ def draw_networkx_edges(G, pos,
arrow_colors
=
edge_colors
arrow_colors
=
edge_colors
if
arrow_colors
is
None
:
if
arrow_colors
is
None
:
if
edge_cmap
is
not
None
:
if
edge_cmap
is
not
None
:
assert
(
isinstance
(
edge_cmap
,
Colormap
)
)
assert
isinstance
(
edge_cmap
,
Colormap
)
else
:
else
:
edge_cmap
=
plt
.
get_cmap
()
# default matplotlib colormap
edge_cmap
=
plt
.
get_cmap
()
# default matplotlib colormap
if
edge_vmin
is
None
:
if
edge_vmin
is
None
:
...
@@ -379,15 +444,18 @@ def draw_networkx_edges(G, pos,
...
@@ -379,15 +444,18 @@ def draw_networkx_edges(G, pos,
line_width
=
lw
[
i
]
line_width
=
lw
[
i
]
else
:
else
:
line_width
=
lw
[
0
]
line_width
=
lw
[
0
]
arrow
=
FancyArrowPatch
((
x1
,
y1
),
(
x2
,
y2
),
arrow
=
FancyArrowPatch
(
arrowstyle
=
arrowstyle
,
(
x1
,
y1
),
shrinkA
=
shrink_source
,
(
x2
,
y2
),
shrinkB
=
shrink_target
,
arrowstyle
=
arrowstyle
,
mutation_scale
=
mutation_scale
,
shrinkA
=
shrink_source
,
connectionstyle
=
connectionstyle
,
shrinkB
=
shrink_target
,
color
=
arrow_color
,
mutation_scale
=
mutation_scale
,
linewidth
=
line_width
,
connectionstyle
=
connectionstyle
,
zorder
=
1
)
# arrows go behind nodes
color
=
arrow_color
,
linewidth
=
line_width
,
zorder
=
1
,
)
# arrows go behind nodes
# There seems to be a bug in matplotlib to make collections of
# There seems to be a bug in matplotlib to make collections of
# FancyArrowPatch instances. Until fixed, the patches are added
# FancyArrowPatch instances. Until fixed, the patches are added
...
@@ -403,7 +471,7 @@ def draw_networkx_edges(G, pos,
...
@@ -403,7 +471,7 @@ def draw_networkx_edges(G, pos,
w
=
maxx
-
minx
w
=
maxx
-
minx
h
=
maxy
-
miny
h
=
maxy
-
miny
padx
,
pady
=
0.05
*
w
,
0.05
*
h
padx
,
pady
=
0.05
*
w
,
0.05
*
h
corners
=
(
minx
-
padx
,
miny
-
pady
),
(
maxx
+
padx
,
maxy
+
pady
)
corners
=
(
minx
-
padx
,
miny
-
pady
),
(
maxx
+
padx
,
maxy
+
pady
)
ax
.
update_datalim
(
corners
)
ax
.
update_datalim
(
corners
)
ax
.
autoscale_view
()
ax
.
autoscale_view
()
...
@@ -412,44 +480,81 @@ def draw_networkx_edges(G, pos,
...
@@ -412,44 +480,81 @@ def draw_networkx_edges(G, pos,
def draw_g(graph):
    """Visualize the transformer's message-passing graph.

    Encoder nodes on the bottom row, decoder nodes on the top row;
    draws enc-enc, dec-dec (causal), and enc-dec edges.
    """
    g = graph.g.to_networkx()
    fig = plt.figure(figsize=(8, 4), dpi=150)
    ax = fig.subplots()
    ax.axis("off")
    ax.set_ylim(-1, 1.5)

    en_indx = graph.nids["enc"].tolist()
    de_indx = graph.nids["dec"].tolist()
    # Layout: encoder at y=0, decoder shifted right and up at y=1.
    en_l = {i: np.array([i, 0]) for i in en_indx}
    de_l = {i: np.array([i + 2, 1]) for i in de_indx}

    # Encoder -> decoder edges (full bipartite).
    en_de_s = []
    for i in en_indx:
        for j in de_indx:
            en_de_s.append((i, j))
            g.add_edge(i, j)

    # Encoder self-attention edges (all pairs).
    en_s = []
    for i in en_indx:
        for j in en_indx:
            g.add_edge(i, j)
            en_s.append((i, j))

    # Decoder self-attention edges; the slice keeps only j >= i
    # (causal masking: no edges to earlier positions).
    de_s = []
    for idx, i in enumerate(de_indx):
        for j in de_indx[idx:]:
            g.add_edge(i, j)
            de_s.append((i, j))

    nx.draw_networkx_nodes(
        g, en_l, nodelist=en_indx, node_color="r", node_size=60, ax=ax
    )
    nx.draw_networkx_nodes(
        g, de_l, nodelist=de_indx, node_color="r", node_size=60, ax=ax
    )
    draw_networkx_edges(
        g,
        en_l,
        edgelist=en_s,
        ax=ax,
        connectionstyle="arc3,rad=-0.3",
        width=0.5,
    )
    draw_networkx_edges(
        g,
        de_l,
        edgelist=de_s,
        ax=ax,
        connectionstyle="arc3,rad=-0.3",
        width=0.5,
    )
    draw_networkx_edges(
        g, {**en_l, **de_l}, edgelist=en_de_s, width=0.3, ax=ax
    )
    # ax.add_patch()
    ax.text(
        len(en_indx) + 0.5,
        0,
        "Encoder",
        verticalalignment="center",
        horizontalalignment="left",
    )
    ax.text(
        len(en_indx) + 0.5,
        1,
        "Decoder",
        verticalalignment="center",
        horizontalalignment="right",
    )
    # Small curved self-loop arrow drawn beside every node.
    delta = 0.03
    for value in {**en_l, **de_l}.values():
        x, y = value
        ax.add_patch(
            FancyArrowPatch(
                (x - delta, y + delta),
                (x - delta, y - delta),
                arrowstyle="->",
                mutation_scale=8,
                connectionstyle="arc3,rad=3",
            )
        )
    plt.show(fig)
examples/pytorch/tree_lstm/train.py
View file @
704bcaf6
...
@@ -2,17 +2,17 @@ import argparse
...
@@ -2,17 +2,17 @@ import argparse
import
collections
import
collections
import
time
import
time
import
dgl
import
numpy
as
np
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.nn.init
as
INIT
import
torch.nn.init
as
INIT
import
torch.optim
as
optim
import
torch.optim
as
optim
from
dgl.data.tree
import
SSTDataset
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
tree_lstm
import
TreeLSTM
from
tree_lstm
import
TreeLSTM
import
dgl
from
dgl.data.tree
import
SSTDataset
SSTBatch
=
collections
.
namedtuple
(
SSTBatch
=
collections
.
namedtuple
(
"SSTBatch"
,
[
"graph"
,
"mask"
,
"wordid"
,
"label"
]
"SSTBatch"
,
[
"graph"
,
"mask"
,
"wordid"
,
"label"
]
)
)
...
...
examples/pytorch/tree_lstm/tree_lstm.py
View file @
704bcaf6
...
@@ -2,14 +2,16 @@
...
@@ -2,14 +2,16 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
https://arxiv.org/abs/1503.00075
"""
"""
import
time
import
itertools
import
itertools
import
time
import
dgl
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
dgl
class
TreeLSTMCell
(
nn
.
Module
):
class
TreeLSTMCell
(
nn
.
Module
):
def
__init__
(
self
,
x_size
,
h_size
):
def
__init__
(
self
,
x_size
,
h_size
):
...
@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module):
...
@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module):
self
.
U_f
=
nn
.
Linear
(
2
*
h_size
,
2
*
h_size
)
self
.
U_f
=
nn
.
Linear
(
2
*
h_size
,
2
*
h_size
)
def
message_func
(
self
,
edges
):
def
message_func
(
self
,
edges
):
return
{
'h'
:
edges
.
src
[
'h'
],
'c'
:
edges
.
src
[
'c'
]}
return
{
"h"
:
edges
.
src
[
"h"
],
"c"
:
edges
.
src
[
"c"
]}
def
reduce_func
(
self
,
nodes
):
def
reduce_func
(
self
,
nodes
):
h_cat
=
nodes
.
mailbox
[
'h'
].
view
(
nodes
.
mailbox
[
'h'
].
size
(
0
),
-
1
)
h_cat
=
nodes
.
mailbox
[
"h"
].
view
(
nodes
.
mailbox
[
"h"
].
size
(
0
),
-
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
h_cat
)).
view
(
*
nodes
.
mailbox
[
'h'
].
size
())
f
=
th
.
sigmoid
(
self
.
U_f
(
h_cat
)).
view
(
*
nodes
.
mailbox
[
"h"
].
size
())
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
'c'
],
1
)
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
"c"
],
1
)
return
{
'
iou
'
:
self
.
U_iou
(
h_cat
),
'c'
:
c
}
return
{
"
iou
"
:
self
.
U_iou
(
h_cat
),
"c"
:
c
}
def
apply_node_func
(
self
,
nodes
):
def
apply_node_func
(
self
,
nodes
):
iou
=
nodes
.
data
[
'
iou
'
]
+
self
.
b_iou
iou
=
nodes
.
data
[
"
iou
"
]
+
self
.
b_iou
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
c
=
i
*
u
+
nodes
.
data
[
'c'
]
c
=
i
*
u
+
nodes
.
data
[
"c"
]
h
=
o
*
th
.
tanh
(
c
)
h
=
o
*
th
.
tanh
(
c
)
return
{
'h'
:
h
,
'c'
:
c
}
return
{
"h"
:
h
,
"c"
:
c
}
class
ChildSumTreeLSTMCell
(
nn
.
Module
):
class
ChildSumTreeLSTMCell
(
nn
.
Module
):
def
__init__
(
self
,
x_size
,
h_size
):
def
__init__
(
self
,
x_size
,
h_size
):
...
@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module):
...
@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module):
self
.
U_f
=
nn
.
Linear
(
h_size
,
h_size
)
self
.
U_f
=
nn
.
Linear
(
h_size
,
h_size
)
def
message_func
(
self
,
edges
):
def
message_func
(
self
,
edges
):
return
{
'h'
:
edges
.
src
[
'h'
],
'c'
:
edges
.
src
[
'c'
]}
return
{
"h"
:
edges
.
src
[
"h"
],
"c"
:
edges
.
src
[
"c"
]}
def
reduce_func
(
self
,
nodes
):
def
reduce_func
(
self
,
nodes
):
h_tild
=
th
.
sum
(
nodes
.
mailbox
[
'h'
],
1
)
h_tild
=
th
.
sum
(
nodes
.
mailbox
[
"h"
],
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
nodes
.
mailbox
[
'h'
]))
f
=
th
.
sigmoid
(
self
.
U_f
(
nodes
.
mailbox
[
"h"
]))
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
'c'
],
1
)
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
"c"
],
1
)
return
{
'
iou
'
:
self
.
U_iou
(
h_tild
),
'c'
:
c
}
return
{
"
iou
"
:
self
.
U_iou
(
h_tild
),
"c"
:
c
}
def
apply_node_func
(
self
,
nodes
):
def
apply_node_func
(
self
,
nodes
):
iou
=
nodes
.
data
[
'
iou
'
]
+
self
.
b_iou
iou
=
nodes
.
data
[
"
iou
"
]
+
self
.
b_iou
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
c
=
i
*
u
+
nodes
.
data
[
'c'
]
c
=
i
*
u
+
nodes
.
data
[
"c"
]
h
=
o
*
th
.
tanh
(
c
)
h
=
o
*
th
.
tanh
(
c
)
return
{
'h'
:
h
,
'c'
:
c
}
return
{
"h"
:
h
,
"c"
:
c
}
class
TreeLSTM
(
nn
.
Module
):
class
TreeLSTM
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
num_vocabs
,
self
,
x_size
,
num_vocabs
,
h_size
,
x_size
,
num_classes
,
h_size
,
dropout
,
num_classes
,
cell_type
=
'nary'
,
dropout
,
pretrained_emb
=
None
):
cell_type
=
"nary"
,
pretrained_emb
=
None
,
):
super
(
TreeLSTM
,
self
).
__init__
()
super
(
TreeLSTM
,
self
).
__init__
()
self
.
x_size
=
x_size
self
.
x_size
=
x_size
self
.
embedding
=
nn
.
Embedding
(
num_vocabs
,
x_size
)
self
.
embedding
=
nn
.
Embedding
(
num_vocabs
,
x_size
)
if
pretrained_emb
is
not
None
:
if
pretrained_emb
is
not
None
:
print
(
'
Using glove
'
)
print
(
"
Using glove
"
)
self
.
embedding
.
weight
.
data
.
copy_
(
pretrained_emb
)
self
.
embedding
.
weight
.
data
.
copy_
(
pretrained_emb
)
self
.
embedding
.
weight
.
requires_grad
=
True
self
.
embedding
.
weight
.
requires_grad
=
True
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
linear
=
nn
.
Linear
(
h_size
,
num_classes
)
self
.
linear
=
nn
.
Linear
(
h_size
,
num_classes
)
cell
=
TreeLSTMCell
if
cell_type
==
'
nary
'
else
ChildSumTreeLSTMCell
cell
=
TreeLSTMCell
if
cell_type
==
"
nary
"
else
ChildSumTreeLSTMCell
self
.
cell
=
cell
(
x_size
,
h_size
)
self
.
cell
=
cell
(
x_size
,
h_size
)
def
forward
(
self
,
batch
,
g
,
h
,
c
):
def
forward
(
self
,
batch
,
g
,
h
,
c
):
...
@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module):
...
@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module):
"""
"""
# feed embedding
# feed embedding
embeds
=
self
.
embedding
(
batch
.
wordid
*
batch
.
mask
)
embeds
=
self
.
embedding
(
batch
.
wordid
*
batch
.
mask
)
g
.
ndata
[
'iou'
]
=
self
.
cell
.
W_iou
(
self
.
dropout
(
embeds
))
*
batch
.
mask
.
float
().
unsqueeze
(
-
1
)
g
.
ndata
[
"iou"
]
=
self
.
cell
.
W_iou
(
g
.
ndata
[
'h'
]
=
h
self
.
dropout
(
embeds
)
g
.
ndata
[
'c'
]
=
c
)
*
batch
.
mask
.
float
().
unsqueeze
(
-
1
)
g
.
ndata
[
"h"
]
=
h
g
.
ndata
[
"c"
]
=
c
# propagate
# propagate
dgl
.
prop_nodes_topo
(
g
,
self
.
cell
.
message_func
,
self
.
cell
.
reduce_func
,
apply_node_func
=
self
.
cell
.
apply_node_func
)
dgl
.
prop_nodes_topo
(
g
,
self
.
cell
.
message_func
,
self
.
cell
.
reduce_func
,
apply_node_func
=
self
.
cell
.
apply_node_func
,
)
# compute logits
# compute logits
h
=
self
.
dropout
(
g
.
ndata
.
pop
(
'h'
))
h
=
self
.
dropout
(
g
.
ndata
.
pop
(
"h"
))
logits
=
self
.
linear
(
h
)
logits
=
self
.
linear
(
h
)
return
logits
return
logits
examples/pytorch/vgae/model.py
View file @
704bcaf6
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
train
import
device
from
dgl.nn.pytorch
import
GraphConv
from
dgl.nn.pytorch
import
GraphConv
from
train
import
device
class
VGAEModel
(
nn
.
Module
):
class
VGAEModel
(
nn
.
Module
):
...
...
examples/pytorch/vgae/train.py
View file @
704bcaf6
...
@@ -2,11 +2,14 @@ import argparse
...
@@ -2,11 +2,14 @@ import argparse
import
os
import
os
import
time
import
time
import
dgl
import
model
import
model
import
numpy
as
np
import
numpy
as
np
import
scipy.sparse
as
sp
import
scipy.sparse
as
sp
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
from
input_data
import
load_data
from
input_data
import
load_data
from
preprocess
import
(
from
preprocess
import
(
mask_test_edges
,
mask_test_edges
,
...
@@ -16,9 +19,6 @@ from preprocess import (
...
@@ -16,9 +19,6 @@ from preprocess import (
)
)
from
sklearn.metrics
import
average_precision_score
,
roc_auc_score
from
sklearn.metrics
import
average_precision_score
,
roc_auc_score
import
dgl
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
os
.
environ
[
"KMP_DUPLICATE_LIB_OK"
]
=
"True"
os
.
environ
[
"KMP_DUPLICATE_LIB_OK"
]
=
"True"
parser
=
argparse
.
ArgumentParser
(
description
=
"Variant Graph Auto Encoder"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Variant Graph Auto Encoder"
)
...
...
examples/pytorch/vrgcn/train_cv.py
View file @
704bcaf6
import
argparse
import
argparse
import
time
import
time
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
import
numpy
as
np
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
import
tqdm
import
tqdm
from
torch.utils.data
import
DataLoader
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
from
dgl.data
import
RedditDataset
from
dgl.data
import
RedditDataset
from
torch.utils.data
import
DataLoader
class
SAGEConvWithCV
(
nn
.
Module
):
class
SAGEConvWithCV
(
nn
.
Module
):
...
...
examples/pytorch/vrgcn/train_cv_multi_gpu.py
View file @
704bcaf6
...
@@ -3,6 +3,10 @@ import math
...
@@ -3,6 +3,10 @@ import math
import
time
import
time
import
traceback
import
traceback
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
import
numpy
as
np
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
...
@@ -10,14 +14,10 @@ import torch.nn as nn
...
@@ -10,14 +14,10 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
import
tqdm
import
tqdm
from
dgl.data
import
RedditDataset
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
from
dgl.data
import
RedditDataset
class
SAGEConvWithCV
(
nn
.
Module
):
class
SAGEConvWithCV
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
activation
):
def
__init__
(
self
,
in_feats
,
out_feats
,
activation
):
...
...
examples/sparse/c_and_s.py
View file @
704bcaf6
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from
dgl.data
import
CoraGraphDataset
from
dgl.data
import
CoraGraphDataset
from
torch.optim
import
Adam
from
torch.optim
import
Adam
###############################################################################
###############################################################################
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
###############################################################################
###############################################################################
...
...
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment