Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8b8fd2c0
Unverified
Commit
8b8fd2c0
authored
Feb 15, 2022
by
Mufei Li
Committed by
GitHub
Feb 15, 2022
Browse files
[Dataset] Add transform argument to built-in datasets (#3733)
* Update * Fix * Update
parent
b3d3a2c4
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
19 deletions
+87
-19
python/dgl/data/tu.py
python/dgl/data/tu.py
+27
-15
tests/compute/test_data.py
tests/compute/test_data.py
+60
-4
No files found.
python/dgl/data/tu.py
View file @
8b8fd2c0
...
@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
Parameters
Parameters
----------
----------
name : str
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
use_pandas : bool
use_pandas : bool
Numpy's file read function has performance issue when file is large,
Numpy's file read function has performance issue when file is large,
...
@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
max_allow_node : int
max_allow_node : int
Remove graphs that contains more nodes than ``max_allow_node``.
Remove graphs that contains more nodes than ``max_allow_node``.
Default : None
Default : None
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
Attributes
----------
----------
...
@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead.
LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead.
If neither labels provided, it uses constant for node feature.
If neither labels provided, it uses constant for node feature.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Shuffle is preferred before manual train/val split.
Examples
Examples
...
@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
def
__init__
(
self
,
name
,
use_pandas
=
False
,
def
__init__
(
self
,
name
,
use_pandas
=
False
,
hidden_size
=
10
,
max_allow_node
=
None
,
hidden_size
=
10
,
max_allow_node
=
None
,
raw_dir
=
None
,
force_reload
=
False
,
verbose
=
False
):
raw_dir
=
None
,
force_reload
=
False
,
verbose
=
False
,
transform
=
None
):
url
=
self
.
_url
.
format
(
name
)
url
=
self
.
_url
.
format
(
name
)
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
self
.
use_pandas
=
use_pandas
self
.
use_pandas
=
use_pandas
super
(
LegacyTUDataset
,
self
).
__init__
(
name
=
name
,
url
=
url
,
raw_dir
=
raw_dir
,
super
(
LegacyTUDataset
,
self
).
__init__
(
name
=
name
,
url
=
url
,
raw_dir
=
raw_dir
,
hash_key
=
(
name
,
use_pandas
,
hidden_size
,
max_allow_node
),
hash_key
=
(
name
,
use_pandas
,
hidden_size
,
max_allow_node
),
force_reload
=
force_reload
,
verbose
=
verbose
)
force_reload
=
force_reload
,
verbose
=
verbose
,
transform
=
transform
)
def
process
(
self
):
def
process
(
self
):
self
.
data_mode
=
None
self
.
data_mode
=
None
...
@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_graph_labels
=
self
.
_idx_from_zero
(
DS_graph_labels
=
self
.
_idx_from_zero
(
np
.
genfromtxt
(
self
.
_file_path
(
"graph_labels"
),
dtype
=
int
))
np
.
genfromtxt
(
self
.
_file_path
(
"graph_labels"
),
dtype
=
int
))
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
DS_graph_labels
self
.
graph_labels
=
DS_graph_labels
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
DS_graph_labels
=
np
.
genfromtxt
(
self
.
_file_path
(
"graph_attributes"
),
dtype
=
float
)
DS_graph_labels
=
np
.
genfromtxt
(
self
.
_file_path
(
"graph_attributes"
),
dtype
=
float
)
self
.
num_labels
=
None
self
.
num_labels
=
None
...
@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
...
@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
And its label.
And its label.
"""
"""
g
=
self
.
graph_lists
[
idx
]
g
=
self
.
graph_lists
[
idx
]
if
self
.
_transform
is
not
None
:
g
=
self
.
_transform
(
g
)
return
g
,
self
.
graph_labels
[
idx
]
return
g
,
self
.
graph_labels
[
idx
]
def
__len__
(
self
):
def
__len__
(
self
):
...
@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
Parameters
Parameters
----------
----------
name : str
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
Attributes
----------
----------
...
@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
:math:`\lbrace 0, 2 \rbrace`.
:math:`\lbrace 0, 2 \rbrace`.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Shuffle is preferred before manual train/val split.
Examples
Examples
...
@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
Graph(num_nodes=9539, num_edges=47382,
Graph(num_nodes=9539, num_edges=47382,
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
"""
"""
_url
=
r
"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
_url
=
r
"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def
__init__
(
self
,
name
,
raw_dir
=
None
,
force_reload
=
False
,
verbose
=
False
):
def
__init__
(
self
,
name
,
raw_dir
=
None
,
force_reload
=
False
,
verbose
=
False
,
transform
=
None
):
url
=
self
.
_url
.
format
(
name
)
url
=
self
.
_url
.
format
(
name
)
super
(
TUDataset
,
self
).
__init__
(
name
=
name
,
url
=
url
,
super
(
TUDataset
,
self
).
__init__
(
name
=
name
,
url
=
url
,
raw_dir
=
raw_dir
,
force_reload
=
force_reload
,
raw_dir
=
raw_dir
,
force_reload
=
force_reload
,
verbose
=
verbose
)
verbose
=
verbose
,
transform
=
transform
)
def
process
(
self
):
def
process
(
self
):
DS_edge_list
=
self
.
_idx_from_zero
(
DS_edge_list
=
self
.
_idx_from_zero
(
loadtxt
(
self
.
_file_path
(
"A"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"A"
),
delimiter
=
","
).
astype
(
int
))
DS_indicator
=
self
.
_idx_from_zero
(
DS_indicator
=
self
.
_idx_from_zero
(
loadtxt
(
self
.
_file_path
(
"graph_indicator"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"graph_indicator"
),
delimiter
=
","
).
astype
(
int
))
if
os
.
path
.
exists
(
self
.
_file_path
(
"graph_labels"
)):
if
os
.
path
.
exists
(
self
.
_file_path
(
"graph_labels"
)):
DS_graph_labels
=
self
.
_idx_reset
(
DS_graph_labels
=
self
.
_idx_reset
(
loadtxt
(
self
.
_file_path
(
"graph_labels"
),
delimiter
=
","
).
astype
(
int
))
loadtxt
(
self
.
_file_path
(
"graph_labels"
),
delimiter
=
","
).
astype
(
int
))
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
num_labels
=
max
(
DS_graph_labels
)
+
1
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
elif
os
.
path
.
exists
(
self
.
_file_path
(
"graph_attributes"
)):
DS_graph_labels
=
loadtxt
(
self
.
_file_path
(
"graph_attributes"
),
delimiter
=
","
).
astype
(
float
)
DS_graph_labels
=
loadtxt
(
self
.
_file_path
(
"graph_attributes"
),
delimiter
=
","
).
astype
(
float
)
self
.
num_labels
=
None
self
.
num_labels
=
None
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
self
.
graph_labels
=
F
.
tensor
(
DS_graph_labels
)
else
:
else
:
raise
Exception
(
"Unknown graph label or graph attributes"
)
raise
Exception
(
"Unknown graph label or graph attributes"
)
...
@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
...
@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
And its label.
And its label.
"""
"""
g
=
self
.
graph_lists
[
idx
]
g
=
self
.
graph_lists
[
idx
]
if
self
.
_transform
is
not
None
:
g
=
self
.
_transform
(
g
)
return
g
,
self
.
graph_labels
[
idx
]
return
g
,
self
.
graph_labels
[
idx
]
def
__len__
(
self
):
def
__len__
(
self
):
...
...
tests/compute/test_data.py
View file @
8b8fd2c0
...
@@ -7,6 +7,7 @@ import os
...
@@ -7,6 +7,7 @@ import os
import
pandas
as
pd
import
pandas
as
pd
import
yaml
import
yaml
import
pytest
import
pytest
import
dgl
import
dgl.data
as
data
import
dgl.data
as
data
from
dgl
import
DGLError
from
dgl
import
DGLError
import
dgl
import
dgl
...
@@ -16,7 +17,11 @@ def test_minigc():
...
@@ -16,7 +17,11 @@ def test_minigc():
ds
=
data
.
MiniGCDataset
(
16
,
10
,
20
)
ds
=
data
.
MiniGCDataset
(
16
,
10
,
20
)
g
,
l
=
list
(
zip
(
*
ds
))
g
,
l
=
list
(
zip
(
*
ds
))
print
(
g
,
l
)
print
(
g
,
l
)
g1
=
ds
[
0
][
0
]
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
ds
=
data
.
MiniGCDataset
(
16
,
10
,
20
,
transform
=
transform
)
g2
=
ds
[
0
][
0
]
assert
g2
.
num_edges
()
-
g1
.
num_edges
()
==
g1
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_gin
():
def
test_gin
():
...
@@ -27,37 +32,64 @@ def test_gin():
...
@@ -27,37 +32,64 @@ def test_gin():
'PROTEINS'
:
1113
,
'PROTEINS'
:
1113
,
'PTC'
:
344
,
'PTC'
:
344
,
}
}
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
for
name
,
n_graphs
in
ds_n_graphs
.
items
():
for
name
,
n_graphs
in
ds_n_graphs
.
items
():
ds
=
data
.
GINDataset
(
name
,
self_loop
=
False
,
degree_as_nlabel
=
False
)
ds
=
data
.
GINDataset
(
name
,
self_loop
=
False
,
degree_as_nlabel
=
False
)
assert
len
(
ds
)
==
n_graphs
,
(
len
(
ds
),
name
)
assert
len
(
ds
)
==
n_graphs
,
(
len
(
ds
),
name
)
g1
=
ds
[
0
][
0
]
ds
=
data
.
GINDataset
(
name
,
self_loop
=
False
,
degree_as_nlabel
=
False
,
transform
=
transform
)
g2
=
ds
[
0
][
0
]
assert
g2
.
num_edges
()
-
g1
.
num_edges
()
==
g1
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_fraud
():
def
test_fraud
():
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
g
=
data
.
FraudDataset
(
'amazon'
)[
0
]
g
=
data
.
FraudDataset
(
'amazon'
)[
0
]
assert
g
.
num_nodes
()
==
11944
assert
g
.
num_nodes
()
==
11944
num_edges1
=
g
.
num_edges
()
g2
=
data
.
FraudDataset
(
'amazon'
,
transform
=
transform
)[
0
]
# 3 edge types
assert
g2
.
num_edges
()
-
num_edges1
==
g
.
num_nodes
()
*
3
g
=
data
.
FraudAmazonDataset
()[
0
]
g
=
data
.
FraudAmazonDataset
()[
0
]
assert
g
.
num_nodes
()
==
11944
assert
g
.
num_nodes
()
==
11944
g2
=
data
.
FraudAmazonDataset
(
transform
=
transform
)[
0
]
# 3 edge types
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
*
3
g
=
data
.
FraudYelpDataset
()[
0
]
g
=
data
.
FraudYelpDataset
()[
0
]
assert
g
.
num_nodes
()
==
45954
assert
g
.
num_nodes
()
==
45954
g2
=
data
.
FraudYelpDataset
(
transform
=
transform
)[
0
]
# 3 edge types
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
*
3
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_fakenews
():
def
test_fakenews
():
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
ds
=
data
.
FakeNewsDataset
(
'politifact'
,
'bert'
)
ds
=
data
.
FakeNewsDataset
(
'politifact'
,
'bert'
)
assert
len
(
ds
)
==
314
assert
len
(
ds
)
==
314
g
=
ds
[
0
][
0
]
g2
=
data
.
FakeNewsDataset
(
'politifact'
,
'bert'
,
transform
=
transform
)[
0
][
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
ds
=
data
.
FakeNewsDataset
(
'gossipcop'
,
'profile'
)
ds
=
data
.
FakeNewsDataset
(
'gossipcop'
,
'profile'
)
assert
len
(
ds
)
==
5464
assert
len
(
ds
)
==
5464
g
=
ds
[
0
][
0
]
g2
=
data
.
FakeNewsDataset
(
'gossipcop'
,
'profile'
,
transform
=
transform
)[
0
][
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_tudataset_regression
():
def
test_tudataset_regression
():
ds
=
data
.
TUDataset
(
'ZINC_test'
,
force_reload
=
True
)
ds
=
data
.
TUDataset
(
'ZINC_test'
,
force_reload
=
True
)
assert
len
(
ds
)
==
5000
assert
len
(
ds
)
==
5000
g
=
ds
[
0
][
0
]
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
ds
=
data
.
TUDataset
(
'ZINC_test'
,
force_reload
=
True
,
transform
=
transform
)
g2
=
ds
[
0
][
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_data_hash
():
def
test_data_hash
():
...
@@ -78,12 +110,16 @@ def test_data_hash():
...
@@ -78,12 +110,16 @@ def test_data_hash():
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_citation_graph
():
def
test_citation_graph
():
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
# cora
# cora
g
=
data
.
CoraGraphDataset
()[
0
]
g
=
data
.
CoraGraphDataset
()[
0
]
assert
g
.
num_nodes
()
==
2708
assert
g
.
num_nodes
()
==
2708
assert
g
.
num_edges
()
==
10556
assert
g
.
num_edges
()
==
10556
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
CoraGraphDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# Citeseer
# Citeseer
g
=
data
.
CiteseerGraphDataset
()[
0
]
g
=
data
.
CiteseerGraphDataset
()[
0
]
...
@@ -91,6 +127,8 @@ def test_citation_graph():
...
@@ -91,6 +127,8 @@ def test_citation_graph():
assert
g
.
num_edges
()
==
9228
assert
g
.
num_edges
()
==
9228
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
CiteseerGraphDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# Pubmed
# Pubmed
g
=
data
.
PubmedGraphDataset
()[
0
]
g
=
data
.
PubmedGraphDataset
()[
0
]
...
@@ -98,16 +136,22 @@ def test_citation_graph():
...
@@ -98,16 +136,22 @@ def test_citation_graph():
assert
g
.
num_edges
()
==
88651
assert
g
.
num_edges
()
==
88651
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
PubmedGraphDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_gnn_benchmark
():
def
test_gnn_benchmark
():
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
# AmazonCoBuyComputerDataset
# AmazonCoBuyComputerDataset
g
=
data
.
AmazonCoBuyComputerDataset
()[
0
]
g
=
data
.
AmazonCoBuyComputerDataset
()[
0
]
assert
g
.
num_nodes
()
==
13752
assert
g
.
num_nodes
()
==
13752
assert
g
.
num_edges
()
==
491722
assert
g
.
num_edges
()
==
491722
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
AmazonCoBuyComputerDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# AmazonCoBuyPhotoDataset
# AmazonCoBuyPhotoDataset
g
=
data
.
AmazonCoBuyPhotoDataset
()[
0
]
g
=
data
.
AmazonCoBuyPhotoDataset
()[
0
]
...
@@ -115,6 +159,8 @@ def test_gnn_benchmark():
...
@@ -115,6 +159,8 @@ def test_gnn_benchmark():
assert
g
.
num_edges
()
==
238163
assert
g
.
num_edges
()
==
238163
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
AmazonCoBuyPhotoDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# CoauthorPhysicsDataset
# CoauthorPhysicsDataset
g
=
data
.
CoauthorPhysicsDataset
()[
0
]
g
=
data
.
CoauthorPhysicsDataset
()[
0
]
...
@@ -122,6 +168,8 @@ def test_gnn_benchmark():
...
@@ -122,6 +168,8 @@ def test_gnn_benchmark():
assert
g
.
num_edges
()
==
495924
assert
g
.
num_edges
()
==
495924
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
CoauthorPhysicsDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# CoauthorCSDataset
# CoauthorCSDataset
g
=
data
.
CoauthorCSDataset
()[
0
]
g
=
data
.
CoauthorCSDataset
()[
0
]
...
@@ -129,6 +177,8 @@ def test_gnn_benchmark():
...
@@ -129,6 +177,8 @@ def test_gnn_benchmark():
assert
g
.
num_edges
()
==
163788
assert
g
.
num_edges
()
==
163788
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
CoauthorCSDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
# CoraFullDataset
# CoraFullDataset
g
=
data
.
CoraFullDataset
()[
0
]
g
=
data
.
CoraFullDataset
()[
0
]
...
@@ -136,6 +186,8 @@ def test_gnn_benchmark():
...
@@ -136,6 +186,8 @@ def test_gnn_benchmark():
assert
g
.
num_edges
()
==
126842
assert
g
.
num_edges
()
==
126842
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
g2
=
data
.
CoraFullDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
...
@@ -147,6 +199,10 @@ def test_reddit():
...
@@ -147,6 +199,10 @@ def test_reddit():
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
dst
=
F
.
asnumpy
(
g
.
edges
()[
1
])
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
assert
np
.
array_equal
(
dst
,
np
.
sort
(
dst
))
transform
=
dgl
.
AddSelfLoop
(
allow_duplicate
=
True
)
g2
=
data
.
RedditDataset
(
transform
=
transform
)[
0
]
assert
g2
.
num_edges
()
-
g
.
num_edges
()
==
g
.
num_nodes
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"Datasets don't need to be tested on GPU."
)
def
test_extract_archive
():
def
test_extract_archive
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment