Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
51ba6621
Unverified
Commit
51ba6621
authored
Aug 26, 2020
by
Jinjing Zhou
Committed by
GitHub
Aug 26, 2020
Browse files
Fix GINDT and #2087 (#2103)
* fix gindt * ff * fix * minor fix * fix
parent
628d9fc5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
23 deletions
+34
-23
python/dgl/data/gindt.py
python/dgl/data/gindt.py
+31
-20
python/dgl/data/graph_serialize.py
python/dgl/data/graph_serialize.py
+3
-3
No files found.
python/dgl/data/gindt.py
View file @
51ba6621
...
...
@@ -15,6 +15,7 @@ from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, down
from
..utils
import
retry_method_with_fix
from
..convert
import
graph
as
dgl_graph
class
GINDataset
(
DGLBuiltinDataset
):
"""Datasets for Graph Isomorphism Network (GIN)
Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
...
...
@@ -232,32 +233,36 @@ class GINDataset(DGLBuiltinDataset):
if
self
.
degree_as_nlabel
:
if
self
.
verbose
:
print
(
'generate node features by node degree...'
)
nlabel_set
=
set
([])
for
g
in
self
.
graphs
:
# actually this label shouldn't be updated
# in case users want to keep it
# but usually no features means no labels, fine.
g
.
ndata
[
'label'
]
=
g
.
in_degrees
()
# extracting unique node labels
nlabel_set
=
nlabel_set
.
union
(
set
([
F
.
as_scalar
(
nl
)
for
nl
in
g
.
ndata
[
'label'
]]))
nlabel_set
=
list
(
nlabel_set
)
# in case the labels/degrees are not continuous number
self
.
ndegree_dict
=
{
# in case the labels/degrees are not continuous number
nlabel_set
=
set
([])
for
g
in
self
.
graphs
:
nlabel_set
=
nlabel_set
.
union
(
set
([
F
.
as_scalar
(
nl
)
for
nl
in
g
.
ndata
[
'label'
]]))
nlabel_set
=
list
(
nlabel_set
)
if
len
(
nlabel_set
)
==
np
.
max
(
nlabel_set
)
+
1
and
np
.
min
(
nlabel_set
)
==
0
:
# Note this is different from the author's implementation. In weihua916's implementation,
# the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
# to make it consistent with the original dataset
label2idx
=
self
.
nlabel_dict
else
:
label2idx
=
{
nlabel_set
[
i
]:
i
for
i
in
range
(
len
(
nlabel_set
))
}
label2idx
=
self
.
ndegree_dict
# generate node attr by node label
else
:
if
self
.
verbose
:
print
(
'generate node features by node label...'
)
label2idx
=
self
.
nlabel_dict
for
g
in
self
.
graphs
:
g
.
ndata
[
'attr'
]
=
F
.
tensor
(
np
.
zeros
((
g
.
number_of_nodes
(),
len
(
label2idx
))))
g
.
ndata
[
'attr'
][
range
(
g
.
number_of_nodes
()),
[
label2idx
[
F
.
as_scalar
(
F
.
reshape
(
nl
,
(
1
,)))]
for
nl
in
g
.
ndata
[
'label'
]]]
=
1
attr
=
np
.
zeros
((
g
.
number_of_nodes
(),
len
(
label2idx
)))
attr
[
range
(
g
.
number_of_nodes
()),
[
label2idx
[
nl
]
for
nl
in
F
.
asnumpy
(
g
.
ndata
[
'label'
]).
tolist
()]]
=
1
g
.
ndata
[
'attr'
]
=
F
.
tensor
(
attr
)
# after load, get the #classes and #dim
self
.
gclasses
=
len
(
self
.
glabel_dict
)
...
...
@@ -288,8 +293,10 @@ class GINDataset(DGLBuiltinDataset):
self
.
nlabel_dict
,
self
.
ndegree_dict
))
def
save
(
self
):
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
label_dict
=
{
'labels'
:
self
.
labels
}
info_dict
=
{
'N'
:
self
.
N
,
'n'
:
self
.
n
,
...
...
@@ -308,8 +315,10 @@ class GINDataset(DGLBuiltinDataset):
save_info
(
str
(
info_path
),
info_dict
)
def
load
(
self
):
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
graphs
,
label_dict
=
load_graphs
(
str
(
graph_path
))
info_dict
=
load_info
(
str
(
info_path
))
...
...
@@ -331,8 +340,10 @@ class GINDataset(DGLBuiltinDataset):
self
.
degree_as_nlabel
=
info_dict
[
'degree_as_nlabel'
]
def
has_cache
(
self
):
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
graph_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.bin'
.
format
(
self
.
name
,
self
.
hash
))
info_path
=
os
.
path
.
join
(
self
.
save_path
,
'gin_{}_{}.pkl'
.
format
(
self
.
name
,
self
.
hash
))
if
os
.
path
.
exists
(
graph_path
)
and
os
.
path
.
exists
(
info_path
):
return
True
return
False
python/dgl/data/graph_serialize.py
View file @
51ba6621
...
...
@@ -108,11 +108,11 @@ def save_graphs(filename, g_list, labels=None):
load_graphs
"""
# if it is local file, do some sanity check
if
filename
.
startswith
(
's3://'
)
is
False
:
if
not
filename
.
startswith
(
's3://'
):
if
os
.
path
.
isdir
(
filename
):
raise
DGLError
(
"Filename {} is an existing directory."
.
format
(
filename
))
f_path
,
_
=
os
.
path
.
split
(
filename
)
if
not
os
.
path
.
exists
(
f_path
):
f_path
=
os
.
path
.
dirname
(
filename
)
if
f_path
and
not
os
.
path
.
exists
(
f_path
):
os
.
makedirs
(
f_path
)
g_sample
=
g_list
[
0
]
if
isinstance
(
g_list
,
list
)
else
g_list
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment