Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7ea777e1
Unverified
Commit
7ea777e1
authored
Jul 11, 2020
by
Mufei Li
Committed by
GitHub
Jul 11, 2020
Browse files
[Example] Fix HAN (#1790)
* Fix * Fix * Fix
parent
8a183e3f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
6 deletions
+13
-6
examples/pytorch/han/main.py
examples/pytorch/han/main.py
+5
-0
examples/pytorch/han/model.py
examples/pytorch/han/model.py
+4
-3
examples/pytorch/han/model_hetero.py
examples/pytorch/han/model_hetero.py
+4
-3
No files found.
examples/pytorch/han/main.py
View file @
7ea777e1
...
@@ -29,6 +29,11 @@ def main(args):
...
@@ -29,6 +29,11 @@ def main(args):
g
,
features
,
labels
,
num_classes
,
train_idx
,
val_idx
,
test_idx
,
train_mask
,
\
g
,
features
,
labels
,
num_classes
,
train_idx
,
val_idx
,
test_idx
,
train_mask
,
\
val_mask
,
test_mask
=
load_data
(
args
[
'dataset'
])
val_mask
,
test_mask
=
load_data
(
args
[
'dataset'
])
if
hasattr
(
torch
,
'BoolTensor'
):
train_mask
=
train_mask
.
bool
()
val_mask
=
val_mask
.
bool
()
test_mask
=
test_mask
.
bool
()
features
=
features
.
to
(
args
[
'device'
])
features
=
features
.
to
(
args
[
'device'
])
labels
=
labels
.
to
(
args
[
'device'
])
labels
=
labels
.
to
(
args
[
'device'
])
train_mask
=
train_mask
.
to
(
args
[
'device'
])
train_mask
=
train_mask
.
to
(
args
[
'device'
])
...
...
examples/pytorch/han/model.py
View file @
7ea777e1
...
@@ -15,10 +15,11 @@ class SemanticAttention(nn.Module):
...
@@ -15,10 +15,11 @@ class SemanticAttention(nn.Module):
)
)
def
forward
(
self
,
z
):
def
forward
(
self
,
z
):
w
=
self
.
project
(
z
)
w
=
self
.
project
(
z
).
mean
(
0
)
# (M, 1)
beta
=
torch
.
softmax
(
w
,
dim
=
1
)
beta
=
torch
.
softmax
(
w
,
dim
=
0
)
# (M, 1)
beta
=
beta
.
expand
((
z
.
shape
[
0
],)
+
beta
.
shape
)
# (N, M, 1)
return
(
beta
*
z
).
sum
(
1
)
return
(
beta
*
z
).
sum
(
1
)
# (N, D * K)
class
HANLayer
(
nn
.
Module
):
class
HANLayer
(
nn
.
Module
):
"""
"""
...
...
examples/pytorch/han/model_hetero.py
View file @
7ea777e1
...
@@ -25,10 +25,11 @@ class SemanticAttention(nn.Module):
...
@@ -25,10 +25,11 @@ class SemanticAttention(nn.Module):
)
)
def
forward
(
self
,
z
):
def
forward
(
self
,
z
):
w
=
self
.
project
(
z
)
w
=
self
.
project
(
z
).
mean
(
0
)
# (M, 1)
beta
=
torch
.
softmax
(
w
,
dim
=
1
)
beta
=
torch
.
softmax
(
w
,
dim
=
0
)
# (M, 1)
beta
=
beta
.
expand
((
z
.
shape
[
0
],)
+
beta
.
shape
)
# (N, M, 1)
return
(
beta
*
z
).
sum
(
1
)
return
(
beta
*
z
).
sum
(
1
)
# (N, D * K)
class
HANLayer
(
nn
.
Module
):
class
HANLayer
(
nn
.
Module
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment