OpenDAS / dgl · Commits

Unverified commit 91b73823, authored Feb 13, 2019 by Minjie Wang, committed by GitHub on Feb 13, 2019.

[Model] update gat (#390)

* update gat: add minus max for softmax
* small fix

Parent: 6c3dba86
Showing 1 changed file with 20 additions and 16 deletions:

examples/pytorch/gat/train.py (+20, -16)
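The core of the change: the GAT example previously computed the edge softmax by exponentiating the attention logits and clamping the result to avoid overflow, which saturates large logits and flattens the attention distribution. This commit switches to the standard subtract-the-max form. A minimal standalone illustration in plain PyTorch (toy values and variable names, not taken from the repo):

    import torch

    logits = torch.tensor([100.0, 102.0])     # two edges pointing at the same destination

    # old scheme in this file: exponentiate, then clamp to dodge overflow
    old = torch.exp(logits).clamp(-10, 10)    # exp() overflows to inf, clamp saturates both to 10
    print(old / old.sum())                    # tensor([0.5000, 0.5000]) -- the ratio is lost

    # new scheme: subtract the per-destination max before exponentiating
    new = torch.exp(logits - logits.max())
    print(new / new.sum())                    # tensor([0.1192, 0.8808]) -- the true softmax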
--- a/examples/pytorch/gat/train.py
+++ b/examples/pytorch/gat/train.py
@@ -39,11 +39,11 @@ class GraphAttention(nn.Module):
         if feat_drop:
             self.feat_drop = nn.Dropout(feat_drop)
         else:
-            self.feat_drop = None
+            self.feat_drop = lambda x : x
         if attn_drop:
             self.attn_drop = nn.Dropout(attn_drop)
         else:
-            self.attn_drop = None
+            self.attn_drop = lambda x : x
         self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
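Replacing None with an identity lambda lets the module call self.feat_drop(...) and self.attn_drop(...) unconditionally, which is what allows forward() and the new edge_softmax() below to drop their if self.feat_drop: / if self.attn_drop: guards. A small sketch of the same pattern (the helper name is made up for illustration):

    import torch
    import torch.nn as nn

    def make_dropout(p):
        # a real Dropout module when p > 0, otherwise an identity callable,
        # so callers never need to branch on whether dropout is enabled
        return nn.Dropout(p) if p else (lambda x: x)

    drop = make_dropout(0.0)
    x = torch.randn(4, 3)
    assert torch.equal(drop(x), x)   # identity when dropout is disabled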
@@ -60,22 +60,19 @@ class GraphAttention(nn.Module):
     def forward(self, inputs):
         # prepare
-        h = inputs  # NxD
-        if self.feat_drop:
-            h = self.feat_drop(h)
+        h = self.feat_drop(inputs)  # NxD
         ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
         head_ft = ft.transpose(0, 1)  # HxNxD'
         a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
         a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
-        if self.feat_drop:
-            ft = self.feat_drop(ft)
         self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
         # 1. compute edge attention
         self.g.apply_edges(self.edge_attention)
-        # 2. compute two results: one is the node features scaled by the dropped,
-        # unnormalized attention values; another is the normalizer of the attention values.
-        self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
-                          [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
+        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
+        self.edge_softmax()
+        # 2. compute the aggregated node features scaled by the dropped,
+        # unnormalized attention values.
+        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
         # 3. apply normalizer
         ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
         # 4. residual
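In the updated forward(), the softmax over incoming edges is done in two message-passing rounds: edge_softmax() first reduces a per-destination max (fn.max) and rewrites every edge value as exp(a - a_max), then sums those shifted exponentials into the normalizer z; the remaining update_all only aggregates ft scaled by the (dropped) attention, and step 3 divides by z. The sketch below re-implements that per-destination softmax on a toy edge list in plain PyTorch, just to show what the two passes compute (graph and tensor names are made up):

    import torch

    # toy graph: edge i runs from src[i] to dst[i]; 'a' holds the unnormalized attention logits
    src = torch.tensor([0, 1, 2, 0])
    dst = torch.tensor([2, 2, 0, 1])
    a = torch.tensor([100.0, 102.0, 3.0, -1.0])   # large logits would overflow a naive exp()

    num_nodes = 3
    a_soft = torch.empty_like(a)
    for v in range(num_nodes):
        mask = dst == v                   # incoming edges of node v
        if mask.any():
            e = a[mask]
            e = torch.exp(e - e.max())    # pass 1: minus the per-destination max, then exp
            a_soft[mask] = e / e.sum()    # pass 2: normalize over v's incoming edges

    print(a_soft)   # edges into node 2 get softmax([100, 102]) = [0.1192, 0.8808]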
@@ -90,10 +87,17 @@ class GraphAttention(nn.Module):
     def edge_attention(self, edges):
         # an edge UDF to compute unnormalized attention values from src and dst
         a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
-        a = torch.exp(a).clamp(-10, 10)  # use clamp to avoid overflow
-        if self.attn_drop:
-            a_drop = self.attn_drop(a)
-        return {'a' : a, 'a_drop' : a_drop}
+        return {'a' : a}
+
+    def edge_softmax(self):
+        # compute the max
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
+        # minus the max and exp
+        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
+        # compute dropout
+        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
+        # compute normalizer
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
 
 class GAT(nn.Module):
     def __init__(self,
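Subtracting a_max per destination does not change the result: softmax is invariant to shifting all logits of a group by the same constant, and the shift just keeps every exponent at or below zero so exp() cannot overflow. As in the old code, the normalizer z is summed from the pre-dropout values, while a_drop only scales the aggregated features. A quick check of the shift-invariance (plain PyTorch, illustrative values):

    import torch

    x = torch.tensor([100.0, 101.0, 102.0])

    naive = torch.exp(x) / torch.exp(x).sum()     # exp() overflows in float32 -> nan
    shifted = torch.exp(x - x.max())
    shifted = shifted / shifted.sum()             # finite and correct

    print(naive)                                  # tensor([nan, nan, nan])
    print(shifted)                                # tensor([0.0900, 0.2447, 0.6652])
    print(torch.allclose(shifted, torch.softmax(x, dim=0)))   # True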
@@ -247,7 +251,7 @@ if __name__ == '__main__':
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
                         help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=300,
+    parser.add_argument("--epochs", type=int, default=200,
                         help="number of training epochs")
     parser.add_argument("--num-heads", type=int, default=8,
                         help="number of hidden attention heads")