Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
704bcaf6
Unverified
Commit
704bcaf6
authored
Feb 19, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 19, 2023
Browse files
examples (#5323)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
6bc82161
Changes
332
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
635 additions
and
341 deletions
+635
-341
examples/pytorch/sagpool/network.py
examples/pytorch/sagpool/network.py
+2
-3
examples/pytorch/seal/logger.py
examples/pytorch/seal/logger.py
+0
-1
examples/pytorch/seal/main.py
examples/pytorch/seal/main.py
+3
-3
examples/pytorch/seal/sampler.py
examples/pytorch/seal/sampler.py
+4
-7
examples/pytorch/seal/utils.py
examples/pytorch/seal/utils.py
+2
-2
examples/pytorch/sgc/sgc.py
examples/pytorch/sgc/sgc.py
+3
-3
examples/pytorch/sgc/sgc_reddit.py
examples/pytorch/sgc/sgc_reddit.py
+2
-2
examples/pytorch/sign/dataset.py
examples/pytorch/sign/dataset.py
+1
-2
examples/pytorch/sign/sign.py
examples/pytorch/sign/sign.py
+3
-3
examples/pytorch/tagcn/train.py
examples/pytorch/tagcn/train.py
+1
-1
examples/pytorch/transformer/modules/act.py
examples/pytorch/transformer/modules/act.py
+152
-69
examples/pytorch/transformer/modules/models.py
examples/pytorch/transformer/modules/models.py
+163
-65
examples/pytorch/transformer/modules/viz.py
examples/pytorch/transformer/modules/viz.py
+235
-130
examples/pytorch/tree_lstm/train.py
examples/pytorch/tree_lstm/train.py
+3
-3
examples/pytorch/tree_lstm/tree_lstm.py
examples/pytorch/tree_lstm/tree_lstm.py
+46
-33
examples/pytorch/vgae/model.py
examples/pytorch/vgae/model.py
+1
-1
examples/pytorch/vgae/train.py
examples/pytorch/vgae/train.py
+3
-3
examples/pytorch/vrgcn/train_cv.py
examples/pytorch/vrgcn/train_cv.py
+5
-5
examples/pytorch/vrgcn/train_cv_multi_gpu.py
examples/pytorch/vrgcn/train_cv_multi_gpu.py
+5
-5
examples/sparse/c_and_s.py
examples/sparse/c_and_s.py
+1
-0
No files found.
examples/pytorch/sagpool/network.py
View file @
704bcaf6
import
dgl
import
torch
import
torch.nn
import
torch.nn.functional
as
F
from
layer
import
ConvPoolBlock
,
SAGPool
import
dgl
from
dgl.nn
import
AvgPooling
,
GraphConv
,
MaxPooling
from
layer
import
ConvPoolBlock
,
SAGPool
class
SAGNetworkHierarchical
(
torch
.
nn
.
Module
):
...
...
examples/pytorch/seal/logger.py
View file @
704bcaf6
...
...
@@ -20,7 +20,6 @@ def _transform_log_level(str_level):
class
LightLogging
(
object
):
def
__init__
(
self
,
log_path
=
None
,
log_name
=
"lightlog"
,
log_level
=
"debug"
):
log_level
=
_transform_log_level
(
log_level
)
if
log_path
:
...
...
examples/pytorch/seal/main.py
View file @
704bcaf6
...
...
@@ -3,6 +3,9 @@ import time
import
numpy
as
np
import
torch
import
torch.multiprocessing
from
dgl
import
EID
,
NID
from
dgl.dataloading
import
GraphDataLoader
from
logger
import
LightLogging
from
model
import
DGCNN
,
GCN
from
sampler
import
SEALData
...
...
@@ -10,9 +13,6 @@ from torch.nn import BCEWithLogitsLoss
from
tqdm
import
tqdm
from
utils
import
evaluate_hits
,
load_ogb_dataset
,
parse_arguments
from
dgl
import
EID
,
NID
from
dgl.dataloading
import
GraphDataLoader
torch
.
multiprocessing
.
set_sharing_strategy
(
"file_system"
)
"""
...
...
examples/pytorch/seal/sampler.py
View file @
704bcaf6
import
os.path
as
osp
from
copy
import
deepcopy
import
dgl
import
torch
from
dgl
import
add_self_loop
,
DGLGraph
,
NID
from
dgl.dataloading.negative_sampler
import
Uniform
from
torch.utils.data
import
DataLoader
,
Dataset
from
tqdm
import
tqdm
from
utils
import
drnl_node_labeling
import
dgl
from
dgl
import
NID
,
DGLGraph
,
add_self_loop
from
dgl.dataloading.negative_sampler
import
Uniform
class
GraphDataSet
(
Dataset
):
"""
...
...
@@ -48,7 +48,6 @@ class PosNegEdgesGenerator(object):
self
.
shuffle
=
shuffle
def
__call__
(
self
,
split_type
):
if
split_type
==
"train"
:
subsample_ratio
=
self
.
subsample_ratio
else
:
...
...
@@ -177,7 +176,6 @@ class SEALSampler(object):
return
subgraph
def
_collate
(
self
,
batch
):
batch_graphs
,
batch_labels
=
map
(
list
,
zip
(
*
batch
))
batch_graphs
=
dgl
.
batch
(
batch_graphs
)
...
...
@@ -272,7 +270,6 @@ class SEALData(object):
)
def
__call__
(
self
,
split_type
):
if
split_type
==
"train"
:
subsample_ratio
=
self
.
subsample_ratio
else
:
...
...
examples/pytorch/seal/utils.py
View file @
704bcaf6
import
argparse
import
dgl
import
numpy
as
np
import
pandas
as
pd
import
torch
from
ogb.linkproppred
import
DglLinkPropPredDataset
,
Evaluator
from
scipy.sparse.csgraph
import
shortest_path
import
dgl
def
parse_arguments
():
"""
...
...
examples/pytorch/sgc/sgc.py
View file @
704bcaf6
...
...
@@ -9,13 +9,13 @@ import argparse
import
math
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
import
dgl.function
as
fn
from
dgl.data
import
(
CiteseerGraphDataset
,
CoraGraphDataset
,
...
...
examples/pytorch/sgc/sgc_reddit.py
View file @
704bcaf6
...
...
@@ -9,12 +9,12 @@ import argparse
import
math
import
time
import
dgl.function
as
fn
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl.function
as
fn
from
dgl
import
DGLGraph
from
dgl.data
import
load_data
,
register_data_args
from
dgl.nn.pytorch.conv
import
SGConv
...
...
examples/pytorch/sign/dataset.py
View file @
704bcaf6
import
dgl
import
numpy
as
np
import
torch
import
dgl
def
load_dataset
(
name
):
dataset
=
name
.
lower
()
...
...
examples/pytorch/sign/sign.py
View file @
704bcaf6
...
...
@@ -2,14 +2,14 @@ import argparse
import
os
import
time
import
dgl
import
dgl.function
as
fn
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
dataset
import
load_dataset
import
dgl
import
dgl.function
as
fn
class
FeedForwardNet
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
hidden
,
out_feats
,
n_layers
,
dropout
):
...
...
examples/pytorch/tagcn/train.py
View file @
704bcaf6
...
...
@@ -6,10 +6,10 @@ import numpy as np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
tagcn
import
TAGCN
from
dgl
import
DGLGraph
from
dgl.data
import
load_data
,
register_data_args
from
tagcn
import
TAGCN
def
evaluate
(
model
,
features
,
labels
,
mask
):
...
...
examples/pytorch/transformer/modules/act.py
View file @
704bcaf6
...
...
@@ -2,32 +2,37 @@ from .attention import *
from
.layers
import
*
from
.functions
import
*
from
.embedding
import
*
import
torch
as
th
import
dgl.function
as
fn
import
torch
as
th
import
torch.nn.init
as
INIT
class
UEncoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
):
super
(
UEncoder
,
self
).
__init__
()
self
.
layer
=
layer
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
fields
=
'
qkv
'
):
def
pre_func
(
self
,
fields
=
"
qkv
"
):
layer
=
self
.
layer
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
func
def
post_func
(
self
):
layer
=
self
.
layer
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
}
return
{
"x"
:
x
}
return
func
...
...
@@ -37,31 +42,36 @@ class UDecoder(nn.Module):
self
.
layer
=
layer
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
fields
=
'
qkv
'
,
l
=
0
):
def
pre_func
(
self
,
fields
=
"
qkv
"
,
l
=
0
):
layer
=
self
.
layer
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
if
fields
==
'
kv
'
:
x
=
nodes
.
data
[
"x"
]
if
fields
==
"
kv
"
:
norm_x
=
x
else
:
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
func
def
post_func
(
self
,
l
=
0
):
layer
=
self
.
layer
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
if
l
==
1
:
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
}
return
{
"x"
:
x
}
return
func
class
HaltingUnit
(
nn
.
Module
):
halting_bias_init
=
1.0
def
__init__
(
self
,
dim_model
):
super
(
HaltingUnit
,
self
).
__init__
()
self
.
linear
=
nn
.
Linear
(
dim_model
,
1
)
...
...
@@ -71,12 +81,25 @@ class HaltingUnit(nn.Module):
def
forward
(
self
,
x
):
return
th
.
sigmoid
(
self
.
linear
(
self
.
norm
(
x
)))
class
UTransformer
(
nn
.
Module
):
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
MAX_DEPTH
=
8
thres
=
0.99
act_loss_weight
=
0.01
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
d_k
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
d_k
,
):
super
(
UTransformer
,
self
).
__init__
()
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
...
...
@@ -91,34 +114,45 @@ class UTransformer(nn.Module):
self
.
stat
=
[
0
]
*
(
self
.
MAX_DEPTH
+
1
)
def
step_forward
(
self
,
nodes
):
x
=
nodes
.
data
[
'x'
]
step
=
nodes
.
data
[
'step'
]
pos
=
nodes
.
data
[
'pos'
]
return
{
'x'
:
self
.
pos_enc
.
dropout
(
x
+
self
.
pos_enc
(
pos
.
view
(
-
1
))
+
self
.
time_enc
(
step
.
view
(
-
1
))),
'step'
:
step
+
1
}
x
=
nodes
.
data
[
"x"
]
step
=
nodes
.
data
[
"step"
]
pos
=
nodes
.
data
[
"pos"
]
return
{
"x"
:
self
.
pos_enc
.
dropout
(
x
+
self
.
pos_enc
(
pos
.
view
(
-
1
))
+
self
.
time_enc
(
step
.
view
(
-
1
))
),
"step"
:
step
+
1
,
}
def
halt_and_accum
(
self
,
name
,
end
=
False
):
"field: 'enc' or 'dec'"
halt
=
self
.
halt_enc
if
name
==
'
enc
'
else
self
.
halt_dec
halt
=
self
.
halt_enc
if
name
==
"
enc
"
else
self
.
halt_dec
thres
=
self
.
thres
def
func
(
nodes
):
p
=
halt
(
nodes
.
data
[
'x'
])
sum_p
=
nodes
.
data
[
'
sum_p
'
]
+
p
p
=
halt
(
nodes
.
data
[
"x"
])
sum_p
=
nodes
.
data
[
"
sum_p
"
]
+
p
active
=
(
sum_p
<
thres
)
&
(
1
-
end
)
_continue
=
active
.
float
()
r
=
nodes
.
data
[
'r'
]
*
(
1
-
_continue
)
+
(
1
-
sum_p
)
*
_continue
s
=
nodes
.
data
[
's'
]
+
((
1
-
_continue
)
*
r
+
_continue
*
p
)
*
nodes
.
data
[
'x'
]
return
{
'p'
:
p
,
'sum_p'
:
sum_p
,
'r'
:
r
,
's'
:
s
,
'active'
:
active
}
r
=
nodes
.
data
[
"r"
]
*
(
1
-
_continue
)
+
(
1
-
sum_p
)
*
_continue
s
=
(
nodes
.
data
[
"s"
]
+
((
1
-
_continue
)
*
r
+
_continue
*
p
)
*
nodes
.
data
[
"x"
]
)
return
{
"p"
:
p
,
"sum_p"
:
sum_p
,
"r"
:
r
,
"s"
:
s
,
"active"
:
active
}
return
func
def
propagate_attention
(
self
,
g
,
eids
):
# Compute attention score
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'
score
'
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'
score
'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
src_dot_dst
(
"k"
,
"q"
,
"
score
"
),
eids
)
g
.
apply_edges
(
scaled_exp
(
"
score
"
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
[
fn
.
u_mul_e
(
'v'
,
'score'
,
'v'
),
fn
.
copy_e
(
'score'
,
'score'
)],
[
fn
.
sum
(
'v'
,
'wv'
),
fn
.
sum
(
'score'
,
'z'
)])
g
.
send_and_recv
(
eids
,
[
fn
.
u_mul_e
(
"v"
,
"score"
,
"v"
),
fn
.
copy_e
(
"score"
,
"score"
)],
[
fn
.
sum
(
"v"
,
"wv"
),
fn
.
sum
(
"score"
,
"z"
)],
)
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
...
...
@@ -136,79 +170,128 @@ class UTransformer(nn.Module):
nids
,
eids
=
graph
.
nids
,
graph
.
eids
# embed & pos
g
.
nodes
[
nids
[
'
enc
'
]].
data
[
'x'
]
=
self
.
src_embed
(
graph
.
src
[
0
])
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'x'
]
=
self
.
tgt_embed
(
graph
.
tgt
[
0
])
g
.
nodes
[
nids
[
'
enc
'
]].
data
[
'
pos
'
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'
pos
'
]
=
graph
.
tgt
[
1
]
g
.
nodes
[
nids
[
"
enc
"
]].
data
[
"x"
]
=
self
.
src_embed
(
graph
.
src
[
0
])
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"x"
]
=
self
.
tgt_embed
(
graph
.
tgt
[
0
])
g
.
nodes
[
nids
[
"
enc
"
]].
data
[
"
pos
"
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"
pos
"
]
=
graph
.
tgt
[
1
]
# init step
device
=
next
(
self
.
parameters
()).
device
g
.
ndata
[
's'
]
=
th
.
zeros
(
N
,
self
.
h
*
self
.
d_k
,
dtype
=
th
.
float
,
device
=
device
)
# accumulated state
g
.
ndata
[
'p'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# halting prob
g
.
ndata
[
'r'
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# remainder
g
.
ndata
[
'sum_p'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# sum of pondering values
g
.
ndata
[
'step'
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
long
,
device
=
device
)
# step
g
.
ndata
[
'active'
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
uint8
,
device
=
device
)
# active
g
.
ndata
[
"s"
]
=
th
.
zeros
(
N
,
self
.
h
*
self
.
d_k
,
dtype
=
th
.
float
,
device
=
device
)
# accumulated state
g
.
ndata
[
"p"
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# halting prob
g
.
ndata
[
"r"
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# remainder
g
.
ndata
[
"sum_p"
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
float
,
device
=
device
)
# sum of pondering values
g
.
ndata
[
"step"
]
=
th
.
zeros
(
N
,
1
,
dtype
=
th
.
long
,
device
=
device
)
# step
g
.
ndata
[
"active"
]
=
th
.
ones
(
N
,
1
,
dtype
=
th
.
uint8
,
device
=
device
)
# active
for
step
in
range
(
self
.
MAX_DEPTH
):
pre_func
=
self
.
encoder
.
pre_func
(
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
()
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'active'
].
view
(
-
1
),
nids
[
'enc'
])
if
len
(
nodes
)
==
0
:
break
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'ee'
])
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
"active"
].
view
(
-
1
),
nids
[
"enc"
]
)
if
len
(
nodes
)
==
0
:
break
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"ee"
]
)
end
=
step
==
self
.
MAX_DEPTH
-
1
self
.
update_graph
(
g
,
edges
,
self
.
update_graph
(
g
,
edges
,
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
'enc'
,
end
),
nodes
)])
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
"enc"
,
end
),
nodes
)],
)
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
encoder
.
norm
(
g
.
nodes
[
nids
[
'enc'
]].
data
[
's'
])
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
encoder
.
norm
(
g
.
nodes
[
nids
[
"enc"
]].
data
[
"s"
]
)
for
step
in
range
(
self
.
MAX_DEPTH
):
pre_func
=
self
.
decoder
.
pre_func
(
'
qkv
'
)
pre_func
=
self
.
decoder
.
pre_func
(
"
qkv
"
)
post_func
=
self
.
decoder
.
post_func
()
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'active'
].
view
(
-
1
),
nids
[
'dec'
])
if
len
(
nodes
)
==
0
:
break
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'dd'
])
self
.
update_graph
(
g
,
edges
,
nodes
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
"active"
].
view
(
-
1
),
nids
[
"dec"
]
)
if
len
(
nodes
)
==
0
:
break
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"dd"
]
)
self
.
update_graph
(
g
,
edges
,
[(
self
.
step_forward
,
nodes
),
(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
[(
post_func
,
nodes
)],
)
pre_q
=
self
.
decoder
.
pre_func
(
'q'
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
'
kv
'
,
1
)
pre_q
=
self
.
decoder
.
pre_func
(
"q"
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
"
kv
"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
1
)
nodes_e
=
nids
[
'enc'
]
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
'active'
].
view
(
-
1
),
eids
[
'ed'
])
nodes_e
=
nids
[
"enc"
]
edges
=
g
.
filter_edges
(
lambda
e
:
e
.
dst
[
"active"
].
view
(
-
1
),
eids
[
"ed"
]
)
end
=
step
==
self
.
MAX_DEPTH
-
1
self
.
update_graph
(
g
,
edges
,
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
'dec'
,
end
),
nodes
)])
[(
post_func
,
nodes
),
(
self
.
halt_and_accum
(
"dec"
,
end
),
nodes
)],
)
g
.
nodes
[
nids
[
'dec'
]].
data
[
'x'
]
=
self
.
decoder
.
norm
(
g
.
nodes
[
nids
[
'dec'
]].
data
[
's'
])
act_loss
=
th
.
mean
(
g
.
ndata
[
'r'
])
# ACT loss
g
.
nodes
[
nids
[
"dec"
]].
data
[
"x"
]
=
self
.
decoder
.
norm
(
g
.
nodes
[
nids
[
"dec"
]].
data
[
"s"
]
)
act_loss
=
th
.
mean
(
g
.
ndata
[
"r"
])
# ACT loss
self
.
stat
[
0
]
+=
N
for
step
in
range
(
1
,
self
.
MAX_DEPTH
+
1
):
self
.
stat
[
step
]
+=
th
.
sum
(
g
.
ndata
[
'
step
'
]
>=
step
).
item
()
self
.
stat
[
step
]
+=
th
.
sum
(
g
.
ndata
[
"
step
"
]
>=
step
).
item
()
return
self
.
generator
(
g
.
ndata
[
'x'
][
nids
[
'dec'
]]),
act_loss
*
self
.
act_loss_weight
return
(
self
.
generator
(
g
.
ndata
[
"x"
][
nids
[
"dec"
]]),
act_loss
*
self
.
act_loss_weight
,
)
def
infer
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
def
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
):
def
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
):
c
=
copy
.
deepcopy
attn
=
MultiHeadAttention
(
h
,
dim_model
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
time_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
encoder
=
UEncoder
(
EncoderLayer
((
dim_model
),
c
(
attn
),
c
(
ff
),
dropout
))
decoder
=
UDecoder
(
DecoderLayer
((
dim_model
),
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
))
decoder
=
UDecoder
(
DecoderLayer
((
dim_model
),
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
)
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
model
=
UTransformer
(
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
dim_model
//
h
)
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
time_enc
,
generator
,
h
,
dim_model
//
h
,
)
# xavier init
for
p
in
model
.
parameters
():
if
p
.
dim
()
>
1
:
...
...
examples/pytorch/transformer/modules/models.py
View file @
704bcaf6
...
...
@@ -6,10 +6,12 @@ from .layers import *
from
.functions
import
*
from
.embedding
import
*
import
threading
import
torch
as
th
import
dgl.function
as
fn
import
torch
as
th
import
torch.nn.init
as
INIT
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
,
N
):
super
(
Encoder
,
self
).
__init__
()
...
...
@@ -17,24 +19,29 @@ class Encoder(nn.Module):
self
.
layers
=
clones
(
layer
,
N
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
i
,
fields
=
'
qkv
'
):
def
pre_func
(
self
,
i
,
fields
=
"
qkv
"
):
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
0
].
norm
(
x
)
return
layer
.
self_attn
.
get
(
norm_x
,
fields
=
fields
)
return
func
def
post_func
(
self
,
i
):
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
0
].
dropout
(
o
)
x
=
layer
.
sublayer
[
1
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
{
"x"
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
func
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
layer
,
N
):
super
(
Decoder
,
self
).
__init__
()
...
...
@@ -42,30 +49,37 @@ class Decoder(nn.Module):
self
.
layers
=
clones
(
layer
,
N
)
self
.
norm
=
LayerNorm
(
layer
.
size
)
def
pre_func
(
self
,
i
,
fields
=
'
qkv
'
,
l
=
0
):
def
pre_func
(
self
,
i
,
fields
=
"
qkv
"
,
l
=
0
):
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
x
=
nodes
.
data
[
'x'
]
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
if
fields
.
startswith
(
'q'
)
else
x
if
fields
!=
'
qkv
'
:
x
=
nodes
.
data
[
"x"
]
norm_x
=
layer
.
sublayer
[
l
].
norm
(
x
)
if
fields
.
startswith
(
"q"
)
else
x
if
fields
!=
"
qkv
"
:
return
layer
.
src_attn
.
get
(
norm_x
,
fields
)
else
:
return
layer
.
self_attn
.
get
(
norm_x
,
fields
)
return
func
def
post_func
(
self
,
i
,
l
=
0
):
layer
=
self
.
layers
[
i
]
def
func
(
nodes
):
x
,
wv
,
z
=
nodes
.
data
[
'x'
],
nodes
.
data
[
'
wv
'
],
nodes
.
data
[
'z'
]
x
,
wv
,
z
=
nodes
.
data
[
"x"
],
nodes
.
data
[
"
wv
"
],
nodes
.
data
[
"z"
]
o
=
layer
.
self_attn
.
get_o
(
wv
/
z
)
x
=
x
+
layer
.
sublayer
[
l
].
dropout
(
o
)
if
l
==
1
:
x
=
layer
.
sublayer
[
2
](
x
,
layer
.
feed_forward
)
return
{
'x'
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
{
"x"
:
x
if
i
<
self
.
N
-
1
else
self
.
norm
(
x
)}
return
func
class
Transformer
(
nn
.
Module
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
d_k
):
def
__init__
(
self
,
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
d_k
):
super
(
Transformer
,
self
).
__init__
()
self
.
encoder
,
self
.
decoder
=
encoder
,
decoder
self
.
src_embed
,
self
.
tgt_embed
=
src_embed
,
tgt_embed
...
...
@@ -76,11 +90,11 @@ class Transformer(nn.Module):
def
propagate_attention
(
self
,
g
,
eids
):
# Compute attention score
g
.
apply_edges
(
src_dot_dst
(
'k'
,
'q'
,
'
score
'
),
eids
)
g
.
apply_edges
(
scaled_exp
(
'
score
'
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
g
.
apply_edges
(
src_dot_dst
(
"k"
,
"q"
,
"
score
"
),
eids
)
g
.
apply_edges
(
scaled_exp
(
"
score
"
,
np
.
sqrt
(
self
.
d_k
)),
eids
)
# Send weighted values to target nodes
g
.
send_and_recv
(
eids
,
fn
.
u_mul_e
(
'v'
,
'
score
'
,
'v'
),
fn
.
sum
(
'v'
,
'
wv
'
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
'
score
'
,
'
score
'
),
fn
.
sum
(
'
score
'
,
'z'
))
g
.
send_and_recv
(
eids
,
fn
.
u_mul_e
(
"v"
,
"
score
"
,
"v"
),
fn
.
sum
(
"v"
,
"
wv
"
))
g
.
send_and_recv
(
eids
,
fn
.
copy_e
(
"
score
"
,
"
score
"
),
fn
.
sum
(
"
score
"
,
"z"
))
def
update_graph
(
self
,
g
,
eids
,
pre_pairs
,
post_pairs
):
"Update the node states and edge states of the graph."
...
...
@@ -98,27 +112,44 @@ class Transformer(nn.Module):
nids
,
eids
=
graph
.
nids
,
graph
.
eids
# embed
src_embed
,
src_pos
=
self
.
src_embed
(
graph
.
src
[
0
]),
self
.
pos_enc
(
graph
.
src
[
1
])
tgt_embed
,
tgt_pos
=
self
.
tgt_embed
(
graph
.
tgt
[
0
]),
self
.
pos_enc
(
graph
.
tgt
[
1
])
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
g
.
nodes
[
nids
[
'dec'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
src_embed
,
src_pos
=
self
.
src_embed
(
graph
.
src
[
0
]),
self
.
pos_enc
(
graph
.
src
[
1
]
)
tgt_embed
,
tgt_pos
=
self
.
tgt_embed
(
graph
.
tgt
[
0
]),
self
.
pos_enc
(
graph
.
tgt
[
1
]
)
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
g
.
nodes
[
nids
[
"dec"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
for
i
in
range
(
self
.
encoder
.
N
):
pre_func
=
self
.
encoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'enc'
],
eids
[
'ee'
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
nodes
,
edges
=
nids
[
"enc"
],
eids
[
"ee"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
for
i
in
range
(
self
.
decoder
.
N
):
pre_func
=
self
.
decoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
decoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
decoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'dec'
],
eids
[
'dd'
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
pre_q
=
self
.
decoder
.
pre_func
(
i
,
'q'
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
'kv'
,
1
)
nodes
,
edges
=
nids
[
"dec"
],
eids
[
"dd"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
pre_q
=
self
.
decoder
.
pre_func
(
i
,
"q"
,
1
)
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
"kv"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
nodes_e
,
edges
=
nids
[
'enc'
],
eids
[
'ed'
]
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
)])
nodes_e
,
edges
=
nids
[
"enc"
],
eids
[
"ed"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes
)],
)
# visualize attention
"""
...
...
@@ -126,9 +157,10 @@ class Transformer(nn.Module):
self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
"""
return
self
.
generator
(
g
.
ndata
[
'x'
][
nids
[
'dec'
]])
return
self
.
generator
(
g
.
ndata
[
"x"
][
nids
[
"dec"
]])
def
infer
(
self
,
graph
,
max_len
,
eos_id
,
k
,
alpha
=
1.0
):
'''
"""
This function implements Beam Search in DGL, which is required in inference phase.
Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
args:
...
...
@@ -138,7 +170,7 @@ class Transformer(nn.Module):
k: beam size
return:
ret: a list of index array correspond to the input sequence specified by `graph``.
'''
"""
g
=
graph
.
g
N
,
E
=
graph
.
n_nodes
,
graph
.
n_edges
nids
,
eids
=
graph
.
nids
,
graph
.
eids
...
...
@@ -146,21 +178,25 @@ class Transformer(nn.Module):
# embed & pos
src_embed
=
self
.
src_embed
(
graph
.
src
[
0
])
src_pos
=
self
.
pos_enc
(
graph
.
src
[
1
])
g
.
nodes
[
nids
[
'enc'
]].
data
[
'pos'
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
'enc'
]].
data
[
'x'
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
g
.
nodes
[
nids
[
"enc"
]].
data
[
"pos"
]
=
graph
.
src
[
1
]
g
.
nodes
[
nids
[
"enc"
]].
data
[
"x"
]
=
self
.
pos_enc
.
dropout
(
src_embed
+
src_pos
)
tgt_pos
=
self
.
pos_enc
(
graph
.
tgt
[
1
])
g
.
nodes
[
nids
[
'
dec
'
]].
data
[
'
pos
'
]
=
graph
.
tgt
[
1
]
g
.
nodes
[
nids
[
"
dec
"
]].
data
[
"
pos
"
]
=
graph
.
tgt
[
1
]
# init mask
device
=
next
(
self
.
parameters
()).
device
g
.
ndata
[
'
mask
'
]
=
th
.
zeros
(
N
,
dtype
=
th
.
uint8
,
device
=
device
)
g
.
ndata
[
"
mask
"
]
=
th
.
zeros
(
N
,
dtype
=
th
.
uint8
,
device
=
device
)
# encode
for
i
in
range
(
self
.
encoder
.
N
):
pre_func
=
self
.
encoder
.
pre_func
(
i
,
'
qkv
'
)
pre_func
=
self
.
encoder
.
pre_func
(
i
,
"
qkv
"
)
post_func
=
self
.
encoder
.
post_func
(
i
)
nodes
,
edges
=
nids
[
'enc'
],
eids
[
'ee'
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
nodes
,
edges
=
nids
[
"enc"
],
eids
[
"ee"
]
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
# decode
log_prob
=
None
...
...
@@ -168,36 +204,76 @@ class Transformer(nn.Module):
for
step
in
range
(
1
,
max_len
):
y
=
y
.
view
(
-
1
)
tgt_embed
=
self
.
tgt_embed
(
y
)
g
.
ndata
[
'x'
][
nids
[
'dec'
]]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
edges_ed
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
'pos'
]
<
step
)
&
~
e
.
dst
[
'mask'
].
bool
(),
eids
[
'ed'
])
edges_dd
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
'pos'
]
<
step
)
&
~
e
.
dst
[
'mask'
].
bool
(),
eids
[
'dd'
])
nodes_d
=
g
.
filter_nodes
(
lambda
v
:
(
v
.
data
[
'pos'
]
<
step
)
&
~
v
.
data
[
'mask'
].
bool
(),
nids
[
'dec'
])
g
.
ndata
[
"x"
][
nids
[
"dec"
]]
=
self
.
pos_enc
.
dropout
(
tgt_embed
+
tgt_pos
)
edges_ed
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
"pos"
]
<
step
)
&
~
e
.
dst
[
"mask"
].
bool
(),
eids
[
"ed"
],
)
edges_dd
=
g
.
filter_edges
(
lambda
e
:
(
e
.
dst
[
"pos"
]
<
step
)
&
~
e
.
dst
[
"mask"
].
bool
(),
eids
[
"dd"
],
)
nodes_d
=
g
.
filter_nodes
(
lambda
v
:
(
v
.
data
[
"pos"
]
<
step
)
&
~
v
.
data
[
"mask"
].
bool
(),
nids
[
"dec"
],
)
for
i
in
range
(
self
.
decoder
.
N
):
pre_func
,
post_func
=
self
.
decoder
.
pre_func
(
i
,
'qkv'
),
self
.
decoder
.
post_func
(
i
)
pre_func
,
post_func
=
self
.
decoder
.
pre_func
(
i
,
"qkv"
),
self
.
decoder
.
post_func
(
i
)
nodes
,
edges
=
nodes_d
,
edges_dd
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)])
pre_q
,
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
'q'
,
1
),
self
.
decoder
.
pre_func
(
i
,
'kv'
,
1
)
self
.
update_graph
(
g
,
edges
,
[(
pre_func
,
nodes
)],
[(
post_func
,
nodes
)]
)
pre_q
,
pre_kv
=
self
.
decoder
.
pre_func
(
i
,
"q"
,
1
),
self
.
decoder
.
pre_func
(
i
,
"kv"
,
1
)
post_func
=
self
.
decoder
.
post_func
(
i
,
1
)
nodes_e
,
nodes_d
,
edges
=
nids
[
'enc'
],
nodes_d
,
edges_ed
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes_d
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes_d
)])
nodes_e
,
nodes_d
,
edges
=
nids
[
"enc"
],
nodes_d
,
edges_ed
self
.
update_graph
(
g
,
edges
,
[(
pre_q
,
nodes_d
),
(
pre_kv
,
nodes_e
)],
[(
post_func
,
nodes_d
)],
)
frontiers
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
'pos'
]
==
step
-
1
,
nids
[
'dec'
])
out
=
self
.
generator
(
g
.
ndata
[
'x'
][
frontiers
])
frontiers
=
g
.
filter_nodes
(
lambda
v
:
v
.
data
[
"pos"
]
==
step
-
1
,
nids
[
"dec"
]
)
out
=
self
.
generator
(
g
.
ndata
[
"x"
][
frontiers
])
batch_size
=
frontiers
.
shape
[
0
]
//
k
vocab_size
=
out
.
shape
[
-
1
]
# Mask output for complete sequence
one_hot
=
th
.
zeros
(
vocab_size
).
fill_
(
-
1e9
).
to
(
device
)
one_hot
[
eos_id
]
=
0
mask
=
g
.
ndata
[
'
mask
'
][
frontiers
].
unsqueeze
(
-
1
).
float
()
mask
=
g
.
ndata
[
"
mask
"
][
frontiers
].
unsqueeze
(
-
1
).
float
()
out
=
out
*
(
1
-
mask
)
+
one_hot
.
unsqueeze
(
0
)
*
mask
if
log_prob
is
None
:
log_prob
,
pos
=
out
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
topk
(
k
,
dim
=-
1
)
log_prob
,
pos
=
out
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
topk
(
k
,
dim
=-
1
)
eos
=
th
.
zeros
(
batch_size
,
k
).
byte
()
else
:
norm_old
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
))
*
np
.
power
((
4.
+
step
)
/
6
,
alpha
)
norm_new
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
))
*
np
.
power
((
5.
+
step
)
/
6
,
alpha
)
log_prob
,
pos
=
((
out
.
view
(
batch_size
,
k
,
-
1
)
+
(
log_prob
*
norm_old
).
unsqueeze
(
-
1
))
/
norm_new
.
unsqueeze
(
-
1
)).
view
(
batch_size
,
-
1
).
topk
(
k
,
dim
=-
1
)
norm_old
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
)
)
*
np
.
power
((
4.0
+
step
)
/
6
,
alpha
)
norm_new
=
eos
.
float
().
to
(
device
)
+
(
1
-
eos
.
float
().
to
(
device
)
)
*
np
.
power
((
5.0
+
step
)
/
6
,
alpha
)
log_prob
,
pos
=
(
(
(
out
.
view
(
batch_size
,
k
,
-
1
)
+
(
log_prob
*
norm_old
).
unsqueeze
(
-
1
)
)
/
norm_new
.
unsqueeze
(
-
1
)
)
.
view
(
batch_size
,
-
1
)
.
topk
(
k
,
dim
=-
1
)
)
_y
=
y
.
view
(
batch_size
*
k
,
-
1
)
y
=
th
.
zeros_like
(
_y
)
...
...
@@ -206,14 +282,16 @@ class Transformer(nn.Module):
for
j
in
range
(
k
):
_j
=
pos
[
i
,
j
].
item
()
//
vocab_size
token
=
pos
[
i
,
j
].
item
()
%
vocab_size
y
[
i
*
k
+
j
,
:]
=
_y
[
i
*
k
+
_j
,
:]
y
[
i
*
k
+
j
,
step
]
=
token
y
[
i
*
k
+
j
,
:]
=
_y
[
i
*
k
+
_j
,
:]
y
[
i
*
k
+
j
,
step
]
=
token
eos
[
i
,
j
]
=
_eos
[
i
,
_j
]
|
(
token
==
eos_id
)
if
eos
.
all
():
break
else
:
g
.
ndata
[
'mask'
][
nids
[
'dec'
]]
=
eos
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
max_len
).
view
(
-
1
).
to
(
device
)
g
.
ndata
[
"mask"
][
nids
[
"dec"
]]
=
(
eos
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
max_len
).
view
(
-
1
).
to
(
device
)
)
return
y
.
view
(
batch_size
,
k
,
-
1
)[:,
0
,
:].
tolist
()
def
_register_att_map
(
self
,
g
,
enc_ids
,
dec_ids
):
...
...
@@ -224,22 +302,42 @@ class Transformer(nn.Module):
]
def
make_model
(
src_vocab
,
tgt_vocab
,
N
=
6
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
,
universal
=
False
):
def
make_model
(
src_vocab
,
tgt_vocab
,
N
=
6
,
dim_model
=
512
,
dim_ff
=
2048
,
h
=
8
,
dropout
=
0.1
,
universal
=
False
,
):
if
universal
:
return
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
,
dim_ff
,
h
,
dropout
)
return
make_universal_model
(
src_vocab
,
tgt_vocab
,
dim_model
,
dim_ff
,
h
,
dropout
)
c
=
copy
.
deepcopy
attn
=
MultiHeadAttention
(
h
,
dim_model
)
ff
=
PositionwiseFeedForward
(
dim_model
,
dim_ff
)
pos_enc
=
PositionalEncoding
(
dim_model
,
dropout
)
encoder
=
Encoder
(
EncoderLayer
(
dim_model
,
c
(
attn
),
c
(
ff
),
dropout
),
N
)
decoder
=
Decoder
(
DecoderLayer
(
dim_model
,
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
),
N
)
decoder
=
Decoder
(
DecoderLayer
(
dim_model
,
c
(
attn
),
c
(
attn
),
c
(
ff
),
dropout
),
N
)
src_embed
=
Embeddings
(
src_vocab
,
dim_model
)
tgt_embed
=
Embeddings
(
tgt_vocab
,
dim_model
)
generator
=
Generator
(
dim_model
,
tgt_vocab
)
model
=
Transformer
(
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
dim_model
//
h
)
encoder
,
decoder
,
src_embed
,
tgt_embed
,
pos_enc
,
generator
,
h
,
dim_model
//
h
,
)
# xavier init
for
p
in
model
.
parameters
():
if
p
.
dim
()
>
1
:
...
...
examples/pytorch/transformer/modules/viz.py
View file @
704bcaf6
import
os
import
numpy
as
np
import
torch
as
th
import
networkx
as
nx
import
matplotlib
as
mpl
import
matplotlib.pyplot
as
plt
import
matplotlib.animation
as
animation
import
matplotlib.pyplot
as
plt
import
networkx
as
nx
import
numpy
as
np
import
torch
as
th
from
networkx.algorithms
import
bipartite
def
get_attention_map
(
g
,
src_nodes
,
dst_nodes
,
h
):
"""
To visualize the attention score between two set of nodes.
...
...
@@ -18,14 +20,15 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
if
not
g
.
has_edge_between
(
src
,
dst
):
continue
eid
=
g
.
edge_ids
(
src
,
dst
)
weight
[
i
][
j
]
=
g
.
edata
[
'
score
'
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
[
i
][
j
]
=
g
.
edata
[
"
score
"
][
eid
].
squeeze
(
-
1
).
cpu
().
detach
()
weight
=
weight
.
transpose
(
0
,
2
)
att
=
th
.
softmax
(
weight
,
-
2
)
return
att
.
numpy
()
def
draw_heatmap
(
array
,
input_seq
,
output_seq
,
dirname
,
name
):
dirname
=
os
.
path
.
join
(
'
log
'
,
dirname
)
dirname
=
os
.
path
.
join
(
"
log
"
,
dirname
)
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
...
...
@@ -38,30 +41,37 @@ def draw_heatmap(array, input_seq, output_seq, dirname, name):
axes
[
i
,
j
].
set_xticks
(
np
.
arange
(
len
(
output_seq
)))
axes
[
i
,
j
].
set_yticklabels
(
input_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_xticklabels
(
output_seq
,
fontsize
=
4
)
axes
[
i
,
j
].
set_title
(
'head_{}'
.
format
(
cnt
),
fontsize
=
10
)
plt
.
setp
(
axes
[
i
,
j
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
rotation_mode
=
"anchor"
)
axes
[
i
,
j
].
set_title
(
"head_{}"
.
format
(
cnt
),
fontsize
=
10
)
plt
.
setp
(
axes
[
i
,
j
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
rotation_mode
=
"anchor"
,
)
cnt
+=
1
fig
.
suptitle
(
name
,
fontsize
=
12
)
plt
.
tight_layout
()
plt
.
savefig
(
os
.
path
.
join
(
dirname
,
'
{}.pdf
'
.
format
(
name
)))
plt
.
savefig
(
os
.
path
.
join
(
dirname
,
"
{}.pdf
"
.
format
(
name
)))
plt
.
close
()
def
draw_atts
(
maps
,
src
,
tgt
,
dirname
,
prefix
):
'''
"""
maps[0]: encoder self-attention
maps[1]: encoder-decoder attention
maps[2]: decoder self-attention
'''
draw_heatmap
(
maps
[
0
],
src
,
src
,
dirname
,
'
{}_enc_self_attn
'
.
format
(
prefix
))
draw_heatmap
(
maps
[
1
],
src
,
tgt
,
dirname
,
'
{}_enc_dec_attn
'
.
format
(
prefix
))
draw_heatmap
(
maps
[
2
],
tgt
,
tgt
,
dirname
,
'
{}_dec_self_attn
'
.
format
(
prefix
))
"""
draw_heatmap
(
maps
[
0
],
src
,
src
,
dirname
,
"
{}_enc_self_attn
"
.
format
(
prefix
))
draw_heatmap
(
maps
[
1
],
src
,
tgt
,
dirname
,
"
{}_enc_dec_attn
"
.
format
(
prefix
))
draw_heatmap
(
maps
[
2
],
tgt
,
tgt
,
dirname
,
"
{}_dec_self_attn
"
.
format
(
prefix
))
mode2id
=
{
'e2e'
:
0
,
'e2d'
:
1
,
'd2d'
:
2
}
mode2id
=
{
"e2e"
:
0
,
"e2d"
:
1
,
"d2d"
:
2
}
colorbar
=
None
def
att_animation
(
maps_array
,
mode
,
src
,
tgt
,
head_id
):
weights
=
[
maps
[
mode2id
[
mode
]][
head_id
]
for
maps
in
maps_array
]
fig
,
axes
=
plt
.
subplots
(
1
,
2
)
...
...
@@ -71,63 +81,112 @@ def att_animation(maps_array, mode, src, tgt, head_id):
if
colorbar
:
colorbar
.
remove
()
plt
.
cla
()
axes
[
0
].
set_title
(
'
heatmap
'
)
axes
[
0
].
set_title
(
"
heatmap
"
)
axes
[
0
].
set_yticks
(
np
.
arange
(
len
(
src
)))
axes
[
0
].
set_xticks
(
np
.
arange
(
len
(
tgt
)))
axes
[
0
].
set_yticklabels
(
src
)
axes
[
0
].
set_xticklabels
(
tgt
)
plt
.
setp
(
axes
[
0
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
rotation_mode
=
"anchor"
)
plt
.
setp
(
axes
[
0
].
get_xticklabels
(),
rotation
=
45
,
ha
=
"right"
,
rotation_mode
=
"anchor"
,
)
fig
.
suptitle
(
'
epoch {}
'
.
format
(
i
))
fig
.
suptitle
(
"
epoch {}
"
.
format
(
i
))
weight
=
weights
[
i
].
transpose
(
-
1
,
-
2
)
heatmap
=
axes
[
0
].
pcolor
(
weight
,
vmin
=
0
,
vmax
=
1
,
cmap
=
plt
.
cm
.
Blues
)
colorbar
=
plt
.
colorbar
(
heatmap
,
ax
=
axes
[
0
],
fraction
=
0.046
,
pad
=
0.04
)
axes
[
0
].
set_aspect
(
'
equal
'
)
axes
[
0
].
set_aspect
(
"
equal
"
)
axes
[
1
].
axis
(
"off"
)
graph_att_head
(
src
,
tgt
,
weight
,
axes
[
1
],
'graph'
)
ani
=
animation
.
FuncAnimation
(
fig
,
weight_animate
,
frames
=
len
(
weights
),
interval
=
500
,
repeat_delay
=
2000
)
graph_att_head
(
src
,
tgt
,
weight
,
axes
[
1
],
"graph"
)
ani
=
animation
.
FuncAnimation
(
fig
,
weight_animate
,
frames
=
len
(
weights
),
interval
=
500
,
repeat_delay
=
2000
,
)
return
ani
def
graph_att_head
(
M
,
N
,
weight
,
ax
,
title
):
"credit: Jinjing Zhou"
in_nodes
=
len
(
M
)
out_nodes
=
len
(
N
)
in_nodes
=
len
(
M
)
out_nodes
=
len
(
N
)
g
=
nx
.
bipartite
.
generators
.
complete_bipartite_graph
(
in_nodes
,
out_nodes
)
g
=
nx
.
bipartite
.
generators
.
complete_bipartite_graph
(
in_nodes
,
out_nodes
)
X
,
Y
=
bipartite
.
sets
(
g
)
height_in
=
10
height_out
=
height_in
height_in_y
=
np
.
linspace
(
0
,
height_in
,
in_nodes
)
height_out_y
=
np
.
linspace
((
height_in
-
height_out
)
/
2
,
height_out
,
out_nodes
)
height_out_y
=
np
.
linspace
(
(
height_in
-
height_out
)
/
2
,
height_out
,
out_nodes
)
pos
=
dict
()
pos
.
update
((
n
,
(
1
,
i
))
for
i
,
n
in
zip
(
height_in_y
,
X
))
# put nodes from X at x=1
pos
.
update
((
n
,
(
3
,
i
))
for
i
,
n
in
zip
(
height_out_y
,
Y
))
# put nodes from Y at x=2
ax
.
axis
(
'off'
)
ax
.
set_xlim
(
-
1
,
4
)
pos
.
update
(
(
n
,
(
1
,
i
))
for
i
,
n
in
zip
(
height_in_y
,
X
)
)
# put nodes from X at x=1
pos
.
update
(
(
n
,
(
3
,
i
))
for
i
,
n
in
zip
(
height_out_y
,
Y
)
)
# put nodes from Y at x=2
ax
.
axis
(
"off"
)
ax
.
set_xlim
(
-
1
,
4
)
ax
.
set_title
(
title
)
nx
.
draw_networkx_nodes
(
g
,
pos
,
nodelist
=
range
(
in_nodes
),
node_color
=
'r'
,
node_size
=
50
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
pos
,
nodelist
=
range
(
in_nodes
,
in_nodes
+
out_nodes
),
node_color
=
'b'
,
node_size
=
50
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
pos
,
nodelist
=
range
(
in_nodes
),
node_color
=
"r"
,
node_size
=
50
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
pos
,
nodelist
=
range
(
in_nodes
,
in_nodes
+
out_nodes
),
node_color
=
"b"
,
node_size
=
50
,
ax
=
ax
,
)
for
edge
in
g
.
edges
():
nx
.
draw_networkx_edges
(
g
,
pos
,
edgelist
=
[
edge
],
width
=
weight
[
edge
[
0
],
edge
[
1
]
-
in_nodes
]
*
1.5
,
ax
=
ax
)
nx
.
draw_networkx_labels
(
g
,
pos
,
{
i
:
label
+
' '
for
i
,
label
in
enumerate
(
M
)},
horizontalalignment
=
'right'
,
font_size
=
8
,
ax
=
ax
)
nx
.
draw_networkx_labels
(
g
,
pos
,
{
i
+
in_nodes
:
' '
+
label
for
i
,
label
in
enumerate
(
N
)},
horizontalalignment
=
'left'
,
font_size
=
8
,
ax
=
ax
)
nx
.
draw_networkx_edges
(
g
,
pos
,
edgelist
=
[
edge
],
width
=
weight
[
edge
[
0
],
edge
[
1
]
-
in_nodes
]
*
1.5
,
ax
=
ax
,
)
nx
.
draw_networkx_labels
(
g
,
pos
,
{
i
:
label
+
" "
for
i
,
label
in
enumerate
(
M
)},
horizontalalignment
=
"right"
,
font_size
=
8
,
ax
=
ax
,
)
nx
.
draw_networkx_labels
(
g
,
pos
,
{
i
+
in_nodes
:
" "
+
label
for
i
,
label
in
enumerate
(
N
)},
horizontalalignment
=
"left"
,
font_size
=
8
,
ax
=
ax
,
)
import
networkx
as
nx
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
from
networkx.utils
import
is_string_like
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
"The following function was modified from the source code of networkx"
def
draw_networkx_edges
(
G
,
pos
,
def
draw_networkx_edges
(
G
,
pos
,
edgelist
=
None
,
width
=
1.0
,
edge_color
=
'k'
,
style
=
'
solid
'
,
edge_color
=
"k"
,
style
=
"
solid
"
,
alpha
=
1.0
,
arrowstyle
=
'
-|>
'
,
arrowstyle
=
"
-|>
"
,
arrowsize
=
10
,
edge_cmap
=
None
,
edge_vmin
=
None
,
...
...
@@ -138,8 +197,9 @@ def draw_networkx_edges(G, pos,
node_size
=
300
,
nodelist
=
None
,
node_shape
=
"o"
,
connectionstyle
=
'arc3'
,
**
kwds
):
connectionstyle
=
"arc3"
,
**
kwds
):
"""Draw the edges of the graph G.
This draws only the edges of the graph G.
...
...
@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos,
"""
try
:
import
matplotlib
import
matplotlib.pyplot
as
plt
import
matplotlib.cbook
as
cb
from
matplotlib.colors
import
colorConverter
,
Colormap
,
Normalize
from
matplotlib.collections
import
LineCollection
from
matplotlib.patches
import
FancyArrowPatch
,
ConnectionStyle
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
matplotlib.collections
import
LineCollection
from
matplotlib.colors
import
colorConverter
,
Colormap
,
Normalize
from
matplotlib.patches
import
ConnectionStyle
,
FancyArrowPatch
except
ImportError
:
raise
ImportError
(
"Matplotlib required for draw()"
)
except
RuntimeError
:
...
...
@@ -270,33 +330,38 @@ def draw_networkx_edges(G, pos,
else
:
lw
=
width
if
not
is_string_like
(
edge_color
)
\
and
cb
.
iterable
(
edge_color
)
\
and
len
(
edge_color
)
==
len
(
edge_pos
):
if
(
not
is_string_like
(
edge_color
)
and
cb
.
iterable
(
edge_color
)
and
len
(
edge_color
)
==
len
(
edge_pos
)
):
if
np
.
alltrue
([
is_string_like
(
c
)
for
c
in
edge_color
]):
# (should check ALL elements)
# list of color letters such as ['k','r','k',...]
edge_colors
=
tuple
([
colorConverter
.
to_rgba
(
c
,
alpha
)
for
c
in
edge_color
])
edge_colors
=
tuple
(
[
colorConverter
.
to_rgba
(
c
,
alpha
)
for
c
in
edge_color
]
)
elif
np
.
alltrue
([
not
is_string_like
(
c
)
for
c
in
edge_color
]):
# If color specs are given as (rgb) or (rgba) tuples, we're OK
if
np
.
alltrue
([
cb
.
iterable
(
c
)
and
len
(
c
)
in
(
3
,
4
)
for
c
in
edge_color
]):
if
np
.
alltrue
(
[
cb
.
iterable
(
c
)
and
len
(
c
)
in
(
3
,
4
)
for
c
in
edge_color
]
):
edge_colors
=
tuple
(
edge_color
)
else
:
# numbers (which are going to be mapped with a colormap)
edge_colors
=
None
else
:
raise
ValueError
(
'
edge_color must contain color names or numbers
'
)
raise
ValueError
(
"
edge_color must contain color names or numbers
"
)
else
:
if
is_string_like
(
edge_color
)
or
len
(
edge_color
)
==
1
:
edge_colors
=
(
colorConverter
.
to_rgba
(
edge_color
,
alpha
),
)
edge_colors
=
(
colorConverter
.
to_rgba
(
edge_color
,
alpha
),)
else
:
msg
=
'
edge_color must be a color or list of one color per edge
'
msg
=
"
edge_color must be a color or list of one color per edge
"
raise
ValueError
(
msg
)
if
(
not
G
.
is_directed
()
or
not
arrows
):
edge_collection
=
LineCollection
(
edge_pos
,
if
not
G
.
is_directed
()
or
not
arrows
:
edge_collection
=
LineCollection
(
edge_pos
,
colors
=
edge_colors
,
linewidths
=
lw
,
antialiaseds
=
(
1
,),
...
...
@@ -318,7 +383,7 @@ def draw_networkx_edges(G, pos,
if
edge_colors
is
None
:
if
edge_cmap
is
not
None
:
assert
(
isinstance
(
edge_cmap
,
Colormap
)
)
assert
isinstance
(
edge_cmap
,
Colormap
)
edge_collection
.
set_array
(
np
.
asarray
(
edge_color
))
edge_collection
.
set_cmap
(
edge_cmap
)
if
edge_vmin
is
not
None
or
edge_vmax
is
not
None
:
...
...
@@ -346,7 +411,7 @@ def draw_networkx_edges(G, pos,
arrow_colors
=
edge_colors
if
arrow_colors
is
None
:
if
edge_cmap
is
not
None
:
assert
(
isinstance
(
edge_cmap
,
Colormap
)
)
assert
isinstance
(
edge_cmap
,
Colormap
)
else
:
edge_cmap
=
plt
.
get_cmap
()
# default matplotlib colormap
if
edge_vmin
is
None
:
...
...
@@ -379,7 +444,9 @@ def draw_networkx_edges(G, pos,
line_width
=
lw
[
i
]
else
:
line_width
=
lw
[
0
]
arrow
=
FancyArrowPatch
((
x1
,
y1
),
(
x2
,
y2
),
arrow
=
FancyArrowPatch
(
(
x1
,
y1
),
(
x2
,
y2
),
arrowstyle
=
arrowstyle
,
shrinkA
=
shrink_source
,
shrinkB
=
shrink_target
,
...
...
@@ -387,7 +454,8 @@ def draw_networkx_edges(G, pos,
connectionstyle
=
connectionstyle
,
color
=
arrow_color
,
linewidth
=
line_width
,
zorder
=
1
)
# arrows go behind nodes
zorder
=
1
,
)
# arrows go behind nodes
# There seems to be a bug in matplotlib to make collections of
# FancyArrowPatch instances. Until fixed, the patches are added
...
...
@@ -412,44 +480,81 @@ def draw_networkx_edges(G, pos,
def
draw_g
(
graph
):
g
=
graph
.
g
.
to_networkx
()
fig
=
plt
.
figure
(
figsize
=
(
8
,
4
),
dpi
=
150
)
ax
=
fig
.
subplots
()
ax
.
axis
(
'
off
'
)
ax
.
set_ylim
(
-
1
,
1.5
)
en_indx
=
graph
.
nids
[
'
enc
'
].
tolist
()
de_indx
=
graph
.
nids
[
'
dec
'
].
tolist
()
en_l
=
{
i
:
np
.
array
([
i
,
0
])
for
i
in
en_indx
}
de_l
=
{
i
:
np
.
array
([
i
+
2
,
1
])
for
i
in
de_indx
}
en_de_s
=
[]
g
=
graph
.
g
.
to_networkx
()
fig
=
plt
.
figure
(
figsize
=
(
8
,
4
),
dpi
=
150
)
ax
=
fig
.
subplots
()
ax
.
axis
(
"
off
"
)
ax
.
set_ylim
(
-
1
,
1.5
)
en_indx
=
graph
.
nids
[
"
enc
"
].
tolist
()
de_indx
=
graph
.
nids
[
"
dec
"
].
tolist
()
en_l
=
{
i
:
np
.
array
([
i
,
0
])
for
i
in
en_indx
}
de_l
=
{
i
:
np
.
array
([
i
+
2
,
1
])
for
i
in
de_indx
}
en_de_s
=
[]
for
i
in
en_indx
:
for
j
in
de_indx
:
en_de_s
.
append
((
i
,
j
))
g
.
add_edge
(
i
,
j
)
en_s
=
[]
en_de_s
.
append
((
i
,
j
))
g
.
add_edge
(
i
,
j
)
en_s
=
[]
for
i
in
en_indx
:
for
j
in
en_indx
:
g
.
add_edge
(
i
,
j
)
en_s
.
append
((
i
,
j
))
g
.
add_edge
(
i
,
j
)
en_s
.
append
((
i
,
j
))
de_s
=
[]
for
idx
,
i
in
enumerate
(
de_indx
):
de_s
=
[]
for
idx
,
i
in
enumerate
(
de_indx
):
for
j
in
de_indx
[
idx
:]:
g
.
add_edge
(
i
,
j
)
de_s
.
append
((
i
,
j
))
g
.
add_edge
(
i
,
j
)
de_s
.
append
((
i
,
j
))
nx
.
draw_networkx_nodes
(
g
,
en_l
,
nodelist
=
en_indx
,
node_color
=
'r'
,
node_size
=
60
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
de_l
,
nodelist
=
de_indx
,
node_color
=
'r'
,
node_size
=
60
,
ax
=
ax
)
draw_networkx_edges
(
g
,
en_l
,
edgelist
=
en_s
,
ax
=
ax
,
connectionstyle
=
"arc3,rad=-0.3"
,
width
=
0.5
)
draw_networkx_edges
(
g
,
de_l
,
edgelist
=
de_s
,
ax
=
ax
,
connectionstyle
=
"arc3,rad=-0.3"
,
width
=
0.5
)
draw_networkx_edges
(
g
,{
**
en_l
,
**
de_l
},
edgelist
=
en_de_s
,
width
=
0.3
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
en_l
,
nodelist
=
en_indx
,
node_color
=
"r"
,
node_size
=
60
,
ax
=
ax
)
nx
.
draw_networkx_nodes
(
g
,
de_l
,
nodelist
=
de_indx
,
node_color
=
"r"
,
node_size
=
60
,
ax
=
ax
)
draw_networkx_edges
(
g
,
en_l
,
edgelist
=
en_s
,
ax
=
ax
,
connectionstyle
=
"arc3,rad=-0.3"
,
width
=
0.5
,
)
draw_networkx_edges
(
g
,
de_l
,
edgelist
=
de_s
,
ax
=
ax
,
connectionstyle
=
"arc3,rad=-0.3"
,
width
=
0.5
,
)
draw_networkx_edges
(
g
,
{
**
en_l
,
**
de_l
},
edgelist
=
en_de_s
,
width
=
0.3
,
ax
=
ax
)
# ax.add_patch()
ax
.
text
(
len
(
en_indx
)
+
0.5
,
0
,
"Encoder"
,
verticalalignment
=
'center'
,
horizontalalignment
=
'left'
)
ax
.
text
(
len
(
en_indx
)
+
0.5
,
0
,
"Encoder"
,
verticalalignment
=
"center"
,
horizontalalignment
=
"left"
,
)
ax
.
text
(
len
(
en_indx
)
+
0.5
,
1
,
"Decoder"
,
verticalalignment
=
'center'
,
horizontalalignment
=
'right'
)
delta
=
0.03
for
value
in
{
**
en_l
,
**
de_l
}.
values
():
x
,
y
=
value
ax
.
add_patch
(
FancyArrowPatch
((
x
-
delta
,
y
+
delta
),(
x
-
delta
,
y
-
delta
),
arrowstyle
=
"->"
,
mutation_scale
=
8
,
connectionstyle
=
"arc3,rad=3"
))
ax
.
text
(
len
(
en_indx
)
+
0.5
,
1
,
"Decoder"
,
verticalalignment
=
"center"
,
horizontalalignment
=
"right"
,
)
delta
=
0.03
for
value
in
{
**
en_l
,
**
de_l
}.
values
():
x
,
y
=
value
ax
.
add_patch
(
FancyArrowPatch
(
(
x
-
delta
,
y
+
delta
),
(
x
-
delta
,
y
-
delta
),
arrowstyle
=
"->"
,
mutation_scale
=
8
,
connectionstyle
=
"arc3,rad=3"
,
)
)
plt
.
show
(
fig
)
examples/pytorch/tree_lstm/train.py
View file @
704bcaf6
...
...
@@ -2,17 +2,17 @@ import argparse
import
collections
import
time
import
dgl
import
numpy
as
np
import
torch
as
th
import
torch.nn.functional
as
F
import
torch.nn.init
as
INIT
import
torch.optim
as
optim
from
dgl.data.tree
import
SSTDataset
from
torch.utils.data
import
DataLoader
from
tree_lstm
import
TreeLSTM
import
dgl
from
dgl.data.tree
import
SSTDataset
SSTBatch
=
collections
.
namedtuple
(
"SSTBatch"
,
[
"graph"
,
"mask"
,
"wordid"
,
"label"
]
)
...
...
examples/pytorch/tree_lstm/tree_lstm.py
View file @
704bcaf6
...
...
@@ -2,14 +2,16 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import
time
import
itertools
import
time
import
dgl
import
networkx
as
nx
import
numpy
as
np
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
class
TreeLSTMCell
(
nn
.
Module
):
def
__init__
(
self
,
x_size
,
h_size
):
...
...
@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module):
self
.
U_f
=
nn
.
Linear
(
2
*
h_size
,
2
*
h_size
)
def
message_func
(
self
,
edges
):
return
{
'h'
:
edges
.
src
[
'h'
],
'c'
:
edges
.
src
[
'c'
]}
return
{
"h"
:
edges
.
src
[
"h"
],
"c"
:
edges
.
src
[
"c"
]}
def
reduce_func
(
self
,
nodes
):
h_cat
=
nodes
.
mailbox
[
'h'
].
view
(
nodes
.
mailbox
[
'h'
].
size
(
0
),
-
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
h_cat
)).
view
(
*
nodes
.
mailbox
[
'h'
].
size
())
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
'c'
],
1
)
return
{
'
iou
'
:
self
.
U_iou
(
h_cat
),
'c'
:
c
}
h_cat
=
nodes
.
mailbox
[
"h"
].
view
(
nodes
.
mailbox
[
"h"
].
size
(
0
),
-
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
h_cat
)).
view
(
*
nodes
.
mailbox
[
"h"
].
size
())
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
"c"
],
1
)
return
{
"
iou
"
:
self
.
U_iou
(
h_cat
),
"c"
:
c
}
def
apply_node_func
(
self
,
nodes
):
iou
=
nodes
.
data
[
'
iou
'
]
+
self
.
b_iou
iou
=
nodes
.
data
[
"
iou
"
]
+
self
.
b_iou
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
c
=
i
*
u
+
nodes
.
data
[
'c'
]
c
=
i
*
u
+
nodes
.
data
[
"c"
]
h
=
o
*
th
.
tanh
(
c
)
return
{
'h'
:
h
,
'c'
:
c
}
return
{
"h"
:
h
,
"c"
:
c
}
class
ChildSumTreeLSTMCell
(
nn
.
Module
):
def
__init__
(
self
,
x_size
,
h_size
):
...
...
@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module):
self
.
U_f
=
nn
.
Linear
(
h_size
,
h_size
)
def
message_func
(
self
,
edges
):
return
{
'h'
:
edges
.
src
[
'h'
],
'c'
:
edges
.
src
[
'c'
]}
return
{
"h"
:
edges
.
src
[
"h"
],
"c"
:
edges
.
src
[
"c"
]}
def
reduce_func
(
self
,
nodes
):
h_tild
=
th
.
sum
(
nodes
.
mailbox
[
'h'
],
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
nodes
.
mailbox
[
'h'
]))
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
'c'
],
1
)
return
{
'
iou
'
:
self
.
U_iou
(
h_tild
),
'c'
:
c
}
h_tild
=
th
.
sum
(
nodes
.
mailbox
[
"h"
],
1
)
f
=
th
.
sigmoid
(
self
.
U_f
(
nodes
.
mailbox
[
"h"
]))
c
=
th
.
sum
(
f
*
nodes
.
mailbox
[
"c"
],
1
)
return
{
"
iou
"
:
self
.
U_iou
(
h_tild
),
"c"
:
c
}
def
apply_node_func
(
self
,
nodes
):
iou
=
nodes
.
data
[
'
iou
'
]
+
self
.
b_iou
iou
=
nodes
.
data
[
"
iou
"
]
+
self
.
b_iou
i
,
o
,
u
=
th
.
chunk
(
iou
,
3
,
1
)
i
,
o
,
u
=
th
.
sigmoid
(
i
),
th
.
sigmoid
(
o
),
th
.
tanh
(
u
)
c
=
i
*
u
+
nodes
.
data
[
'c'
]
c
=
i
*
u
+
nodes
.
data
[
"c"
]
h
=
o
*
th
.
tanh
(
c
)
return
{
'h'
:
h
,
'c'
:
c
}
return
{
"h"
:
h
,
"c"
:
c
}
class
TreeLSTM
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
num_vocabs
,
x_size
,
h_size
,
num_classes
,
dropout
,
cell_type
=
'nary'
,
pretrained_emb
=
None
):
cell_type
=
"nary"
,
pretrained_emb
=
None
,
):
super
(
TreeLSTM
,
self
).
__init__
()
self
.
x_size
=
x_size
self
.
embedding
=
nn
.
Embedding
(
num_vocabs
,
x_size
)
if
pretrained_emb
is
not
None
:
print
(
'
Using glove
'
)
print
(
"
Using glove
"
)
self
.
embedding
.
weight
.
data
.
copy_
(
pretrained_emb
)
self
.
embedding
.
weight
.
requires_grad
=
True
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
linear
=
nn
.
Linear
(
h_size
,
num_classes
)
cell
=
TreeLSTMCell
if
cell_type
==
'
nary
'
else
ChildSumTreeLSTMCell
cell
=
TreeLSTMCell
if
cell_type
==
"
nary
"
else
ChildSumTreeLSTMCell
self
.
cell
=
cell
(
x_size
,
h_size
)
def
forward
(
self
,
batch
,
g
,
h
,
c
):
...
...
@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module):
"""
# feed embedding
embeds
=
self
.
embedding
(
batch
.
wordid
*
batch
.
mask
)
g
.
ndata
[
'iou'
]
=
self
.
cell
.
W_iou
(
self
.
dropout
(
embeds
))
*
batch
.
mask
.
float
().
unsqueeze
(
-
1
)
g
.
ndata
[
'h'
]
=
h
g
.
ndata
[
'c'
]
=
c
g
.
ndata
[
"iou"
]
=
self
.
cell
.
W_iou
(
self
.
dropout
(
embeds
)
)
*
batch
.
mask
.
float
().
unsqueeze
(
-
1
)
g
.
ndata
[
"h"
]
=
h
g
.
ndata
[
"c"
]
=
c
# propagate
dgl
.
prop_nodes_topo
(
g
,
self
.
cell
.
message_func
,
self
.
cell
.
reduce_func
,
apply_node_func
=
self
.
cell
.
apply_node_func
)
dgl
.
prop_nodes_topo
(
g
,
self
.
cell
.
message_func
,
self
.
cell
.
reduce_func
,
apply_node_func
=
self
.
cell
.
apply_node_func
,
)
# compute logits
h
=
self
.
dropout
(
g
.
ndata
.
pop
(
'h'
))
h
=
self
.
dropout
(
g
.
ndata
.
pop
(
"h"
))
logits
=
self
.
linear
(
h
)
return
logits
examples/pytorch/vgae/model.py
View file @
704bcaf6
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
train
import
device
from
dgl.nn.pytorch
import
GraphConv
from
train
import
device
class
VGAEModel
(
nn
.
Module
):
...
...
examples/pytorch/vgae/train.py
View file @
704bcaf6
...
...
@@ -2,11 +2,14 @@ import argparse
import
os
import
time
import
dgl
import
model
import
numpy
as
np
import
scipy.sparse
as
sp
import
torch
import
torch.nn.functional
as
F
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
from
input_data
import
load_data
from
preprocess
import
(
mask_test_edges
,
...
...
@@ -16,9 +19,6 @@ from preprocess import (
)
from
sklearn.metrics
import
average_precision_score
,
roc_auc_score
import
dgl
from
dgl.data
import
CiteseerGraphDataset
,
CoraGraphDataset
,
PubmedGraphDataset
os
.
environ
[
"KMP_DUPLICATE_LIB_OK"
]
=
"True"
parser
=
argparse
.
ArgumentParser
(
description
=
"Variant Graph Auto Encoder"
)
...
...
examples/pytorch/vrgcn/train_cv.py
View file @
704bcaf6
import
argparse
import
time
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
import
numpy
as
np
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
tqdm
from
torch.utils.data
import
DataLoader
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
from
dgl.data
import
RedditDataset
from
torch.utils.data
import
DataLoader
class
SAGEConvWithCV
(
nn
.
Module
):
...
...
examples/pytorch/vrgcn/train_cv_multi_gpu.py
View file @
704bcaf6
...
...
@@ -3,6 +3,10 @@ import math
import
time
import
traceback
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
import
numpy
as
np
import
torch
as
th
import
torch.multiprocessing
as
mp
...
...
@@ -10,14 +14,10 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
tqdm
from
dgl.data
import
RedditDataset
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.utils.data
import
DataLoader
import
dgl
import
dgl.function
as
fn
import
dgl.nn.pytorch
as
dglnn
from
dgl.data
import
RedditDataset
class
SAGEConvWithCV
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
activation
):
...
...
examples/sparse/c_and_s.py
View file @
704bcaf6
...
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from
dgl.data
import
CoraGraphDataset
from
torch.optim
import
Adam
###############################################################################
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
###############################################################################
...
...
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment