Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a303f078
"src/array/vscode:/vscode.git/clone" did not exist on "314cedc1b1c3c5ffd2ee0a980010b62faf120f1f"
Unverified
Commit
a303f078
authored
Jun 14, 2021
by
Jinjing Zhou
Committed by
GitHub
Jun 14, 2021
Browse files
fix #2952 (#3010)
parent
17141dd3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
8 deletions
+29
-8
python/dgl/data/tu.py
python/dgl/data/tu.py
+23
-8
tests/compute/test_data.py
tests/compute/test_data.py
+6
-0
No files found.
python/dgl/data/tu.py
View file @
a303f078
...
@@ -93,8 +93,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -93,8 +93,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_indicator
=
self
.
_idx_from_zero
(
DS_indicator
=
self
.
_idx_from_zero
(
np
.
genfromtxt
(
self
.
_file_path
(
"graph_indicator"
),
dtype
=
int
))
np
.
genfromtxt
(
self
.
_file_path
(
"graph_indicator"
),
dtype
=
int
))
if
os
.
path
.
exists
(
self
.
_file_path
(
"graph_labels"
)):
DS_graph_labels
=
self
.
_idx_from_zero
(
DS_graph_labels
=
self
.
_idx_from_zero
(
np
.
genfromtxt
(
self
.
_file_path
(
"graph_labels"
),
dtype
=
int
))
np
.
genfromtxt
(
self
.
_file_path
(
"graph_labels"
),
dtype
=
int
))
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
DS_graph_labels
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
DS_graph_labels
=
np
.
genfromtxt
(
self
.
_file_path
(
"graph_attributes"
),
dtype
=
float
)
self
.
num_labels
=
None
self
.
graph_labels
=
DS_graph_labels
else
:
raise
Exception
(
"Unknown graph label or graph attributes"
)
g
=
dgl_graph
(([],
[]))
g
=
dgl_graph
(([],
[]))
g
.
add_nodes
(
int
(
DS_edge_list
.
max
())
+
1
)
g
.
add_nodes
(
int
(
DS_edge_list
.
max
())
+
1
)
...
@@ -109,8 +118,6 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -109,8 +118,6 @@ class LegacyTUDataset(DGLBuiltinDataset):
self
.
max_num_node
=
len
(
node_idx
[
0
])
self
.
max_num_node
=
len
(
node_idx
[
0
])
self
.
graph_lists
=
[
g
.
subgraph
(
node_idx
)
for
node_idx
in
node_idx_list
]
self
.
graph_lists
=
[
g
.
subgraph
(
node_idx
)
for
node_idx
in
node_idx_list
]
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
DS_graph_labels
try
:
try
:
DS_node_labels
=
self
.
_idx_from_zero
(
DS_node_labels
=
self
.
_idx_from_zero
(
...
@@ -296,8 +303,18 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -296,8 +303,18 @@ class TUDataset(DGLBuiltinDataset):
loadtxt
(
self
.
_file_path
(
"A"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"A"
),
delimiter
=
","
).
astype
(
int
))
DS_indicator
=
self
.
_idx_from_zero
(
DS_indicator
=
self
.
_idx_from_zero
(
loadtxt
(
self
.
_file_path
(
"graph_indicator"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"graph_indicator"
),
delimiter
=
","
).
astype
(
int
))
if
os
.
path
.
exists
(
self
.
_file_path
(
"graph_labels"
)):
DS_graph_labels
=
self
.
_idx_reset
(
DS_graph_labels
=
self
.
_idx_reset
(
loadtxt
(
self
.
_file_path
(
"graph_labels"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"graph_labels"
),
delimiter
=
","
).
astype
(
int
))
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
DS_graph_labels
=
loadtxt
(
self
.
_file_path
(
"graph_attributes"
),
delimiter
=
","
).
astype
(
float
)
self
.
num_labels
=
None
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
else
:
raise
Exception
(
"Unknown graph label or graph attributes"
)
g
=
dgl_graph
(([],
[]))
g
=
dgl_graph
(([],
[]))
g
.
add_nodes
(
int
(
DS_edge_list
.
max
())
+
1
)
g
.
add_nodes
(
int
(
DS_edge_list
.
max
())
+
1
)
...
@@ -311,8 +328,6 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -311,8 +328,6 @@ class TUDataset(DGLBuiltinDataset):
if
len
(
node_idx
[
0
])
>
self
.
max_num_node
:
if
len
(
node_idx
[
0
])
>
self
.
max_num_node
:
self
.
max_num_node
=
len
(
node_idx
[
0
])
self
.
max_num_node
=
len
(
node_idx
[
0
])
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
self
.
attr_dict
=
{
self
.
attr_dict
=
{
'node_labels'
:
(
'ndata'
,
'node_labels'
),
'node_labels'
:
(
'ndata'
,
'node_labels'
),
...
...
tests/compute/test_data.py
View file @
a303f078
...
@@ -24,6 +24,12 @@ def test_gin():
...
@@ -24,6 +24,12 @@ def test_gin():
assert
len
(
ds
)
==
n_graphs
,
(
len
(
ds
),
name
)
assert
len
(
ds
)
==
n_graphs
,
(
len
(
ds
),
name
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_tudataset_regression
():
ds
=
data
.
TUDataset
(
'ZINC_test'
,
force_reload
=
True
)
assert
len
(
ds
)
==
5000
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_data_hash
():
def
test_data_hash
():
class
HashTestDataset
(
data
.
DGLDataset
):
class
HashTestDataset
(
data
.
DGLDataset
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment