Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6e8f7605
Unverified
Commit
6e8f7605
authored
Aug 19, 2020
by
Mufei Li
Committed by
GitHub
Aug 19, 2020
Browse files
Update (#2057)
parent
75e89a15
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
4 deletions
+5
-4
examples/pytorch/han/model_hetero.py
examples/pytorch/han/model_hetero.py
+2
-1
examples/pytorch/han/utils.py
examples/pytorch/han/utils.py
+3
-3
No files found.
examples/pytorch/han/model_hetero.py
View file @
6e8f7605
...
@@ -62,7 +62,8 @@ class HANLayer(nn.Module):
...
@@ -62,7 +62,8 @@ class HANLayer(nn.Module):
self
.
gat_layers
=
nn
.
ModuleList
()
self
.
gat_layers
=
nn
.
ModuleList
()
for
i
in
range
(
len
(
meta_paths
)):
for
i
in
range
(
len
(
meta_paths
)):
self
.
gat_layers
.
append
(
GATConv
(
in_size
,
out_size
,
layer_num_heads
,
self
.
gat_layers
.
append
(
GATConv
(
in_size
,
out_size
,
layer_num_heads
,
dropout
,
dropout
,
activation
=
F
.
elu
))
dropout
,
dropout
,
activation
=
F
.
elu
,
allow_zero_in_degree
=
True
))
self
.
semantic_attention
=
SemanticAttention
(
in_size
=
out_size
*
layer_num_heads
)
self
.
semantic_attention
=
SemanticAttention
(
in_size
=
out_size
*
layer_num_heads
)
self
.
meta_paths
=
list
(
tuple
(
meta_path
)
for
meta_path
in
meta_paths
)
self
.
meta_paths
=
list
(
tuple
(
meta_path
)
for
meta_path
in
meta_paths
)
...
...
examples/pytorch/han/utils.py
View file @
6e8f7605
...
@@ -99,7 +99,7 @@ def setup(args):
...
@@ -99,7 +99,7 @@ def setup(args):
args
.
update
(
default_configure
)
args
.
update
(
default_configure
)
set_random_seed
(
args
[
'seed'
])
set_random_seed
(
args
[
'seed'
])
args
[
'dataset'
]
=
'ACMRaw'
if
args
[
'hetero'
]
else
'ACM'
args
[
'dataset'
]
=
'ACMRaw'
if
args
[
'hetero'
]
else
'ACM'
args
[
'device'
]
=
'cuda:
0'
if
torch
.
cuda
.
is_available
()
else
'cpu'
args
[
'device'
]
=
'cuda:0'
if
torch
.
cuda
.
is_available
()
else
'cpu'
args
[
'log_dir'
]
=
setup_log_dir
(
args
)
args
[
'log_dir'
]
=
setup_log_dir
(
args
)
return
args
return
args
...
@@ -107,7 +107,7 @@ def setup_for_sampling(args):
...
@@ -107,7 +107,7 @@ def setup_for_sampling(args):
args
.
update
(
default_configure
)
args
.
update
(
default_configure
)
args
.
update
(
sampling_configure
)
args
.
update
(
sampling_configure
)
set_random_seed
()
set_random_seed
()
args
[
'device'
]
=
'cuda:
0'
if
torch
.
cuda
.
is_available
()
else
'cpu'
args
[
'device'
]
=
'cuda:0'
if
torch
.
cuda
.
is_available
()
else
'cpu'
args
[
'log_dir'
]
=
setup_log_dir
(
args
,
sampling
=
True
)
args
[
'log_dir'
]
=
setup_log_dir
(
args
,
sampling
=
True
)
return
args
return
args
...
@@ -188,7 +188,7 @@ def load_acm_raw(remove_self_loop):
...
@@ -188,7 +188,7 @@ def load_acm_raw(remove_self_loop):
hg
=
dgl
.
heterograph
({
hg
=
dgl
.
heterograph
({
(
'paper'
,
'pa'
,
'author'
):
p_vs_a
.
nonzero
(),
(
'paper'
,
'pa'
,
'author'
):
p_vs_a
.
nonzero
(),
(
'author'
,
'ap'
,
'paper'
):
p_vs_a
.
transpose
.
nonzero
(),
(
'author'
,
'ap'
,
'paper'
):
p_vs_a
.
transpose
()
.
nonzero
(),
(
'paper'
,
'pf'
,
'field'
):
p_vs_l
.
nonzero
(),
(
'paper'
,
'pf'
,
'field'
):
p_vs_l
.
nonzero
(),
(
'field'
,
'fp'
,
'paper'
):
p_vs_l
.
transpose
().
nonzero
()
(
'field'
,
'fp'
,
'paper'
):
p_vs_l
.
transpose
().
nonzero
()
})
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment