Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ad4df9c5
Unverified
Commit
ad4df9c5
authored
Oct 12, 2023
by
paoxiaode
Committed by
GitHub
Oct 12, 2023
Browse files
[Dataset] update docstring of LRGB (#6430)
parent
ec5c515c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
78 deletions
+122
-78
python/dgl/data/lrgb.py
python/dgl/data/lrgb.py
+122
-78
No files found.
python/dgl/data/lrgb.py
View file @
ad4df9c5
...
@@ -3,7 +3,7 @@ import os
...
@@ -3,7 +3,7 @@ import os
import
pickle
import
pickle
import
pandas
as
pd
import
pandas
as
pd
from
ogb.utils
import
smiles2graph
from
ogb.utils
import
smiles2graph
as
smiles2graph_OGB
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
..
import
backend
as
F
from
..
import
backend
as
F
...
@@ -23,21 +23,22 @@ from .utils import (
...
@@ -23,21 +23,22 @@ from .utils import (
class
PeptidesStructuralDataset
(
DGLDataset
):
class
PeptidesStructuralDataset
(
DGLDataset
):
r
"""Peptides structure dataset for the graph regression task.
r
"""Peptides structure dataset for the graph regression task.
DGL dataset of 15,535 small peptides represented as their molecular
DGL dataset of Peptides-struct in the LRGB benchmark which contains
graph (SMILES) with 11 regression targets derived from the peptide's
15,535 small peptides represented as their molecular graph (SMILES)
3D structure.
with 11 regression targets derived from the peptide's 3D structure.
The 11 regression targets were precomputed from molecules' 3D structure:
The 11 regression targets were precomputed from molecules' 3D structure:
Inertia_mass_[a-c]: The principal component of the inertia of the
- Inertia_mass_[a-c]: The principal component of the inertia of the
mass, with some normalizations. (Sorted)
mass, with some normalizations. (Sorted)
Inertia_valence_[a-c]: The principal component of the inertia of the
-
Inertia_valence_[a-c]: The principal component of the inertia of the
Hydrogen atoms. This is basically a measure of the 3D
Hydrogen atoms. This is basically a measure of the 3D
distribution of hydrogens. (Sorted)
distribution of hydrogens. (Sorted)
length_[a-c]: The length around the 3 main geometric axes of
- length_[a-c]: The length around the 3 main geometric axes of
the 3D objects (without considering atom types). (Sorted)
the 3D objects (without considering atom types). (Sorted)
Spherocity: SpherocityIndex descriptor computed by
-
Spherocity: SpherocityIndex descriptor computed by
rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex
rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex
Plane_best_fit: Plane of best fit (PBF) descriptor computed by
-
Plane_best_fit: Plane of best fit (PBF) descriptor computed by
rdkit.Chem.rdMolDescriptors.CalcPBF
rdkit.Chem.rdMolDescriptors.CalcPBF
Reference `<https://arxiv.org/abs/2206.08164.pdf>`_
Reference `<https://arxiv.org/abs/2206.08164.pdf>`_
...
@@ -87,7 +88,8 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -87,7 +88,8 @@ class PeptidesStructuralDataset(DGLDataset):
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> # support tensor to be index when transform is None
>>> # see details in __getitem__ function
>>> # get train dataset
>>> # get train dataset
>>> split_dict = dataset.get_idx_split()
>>> split_dict = dataset.get_idx_split()
>>> trainset = dataset[split_dict["train"]]
>>> trainset = dataset[split_dict["train"]]
...
@@ -114,7 +116,7 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -114,7 +116,7 @@ class PeptidesStructuralDataset(DGLDataset):
force_reload
=
None
,
force_reload
=
None
,
verbose
=
None
,
verbose
=
None
,
transform
=
None
,
transform
=
None
,
smiles2graph
=
smiles2graph
,
smiles2graph
=
smiles2graph
_OGB
,
):
):
self
.
smiles2graph
=
smiles2graph
self
.
smiles2graph
=
smiles2graph
# MD5 hash of the dataset file.
# MD5 hash of the dataset file.
...
@@ -123,8 +125,10 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -123,8 +125,10 @@ class PeptidesStructuralDataset(DGLDataset):
https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1
https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1
"""
"""
self
.
md5sum_stratified_split
=
"5a0114bdadc80b94fc7ae974f13ef061"
self
.
md5sum_stratified_split
=
"5a0114bdadc80b94fc7ae974f13ef061"
self
.
graphs
=
[]
self
.
labels
=
[]
super
(
PeptidesStructuralDataset
,
self
).
__init__
(
super
().
__init__
(
name
=
"Peptides-struc"
,
name
=
"Peptides-struc"
,
raw_dir
=
raw_dir
,
raw_dir
=
raw_dir
,
url
=
"""
url
=
"""
...
@@ -137,40 +141,45 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -137,40 +141,45 @@ class PeptidesStructuralDataset(DGLDataset):
@
property
@
property
def
raw_data_path
(
self
):
def
raw_data_path
(
self
):
r
"""Path to save the raw dataset file."""
return
os
.
path
.
join
(
self
.
raw_path
,
"peptide_structure_dataset.csv.gz"
)
return
os
.
path
.
join
(
self
.
raw_path
,
"peptide_structure_dataset.csv.gz"
)
@
property
@
property
def
split_data_path
(
self
):
def
split_data_path
(
self
):
r
"""Path to save the dataset split file."""
return
os
.
path
.
join
(
return
os
.
path
.
join
(
self
.
raw_path
,
"splits_random_stratified_peptide_structure.pickle"
self
.
raw_path
,
"splits_random_stratified_peptide_structure.pickle"
)
)
@
property
@
property
def
graph_path
(
self
):
def
graph_path
(
self
):
r
"""Path to save the processed dataset file."""
return
os
.
path
.
join
(
self
.
save_path
,
"Peptides-struc.bin"
)
return
os
.
path
.
join
(
self
.
save_path
,
"Peptides-struc.bin"
)
@
property
@
property
def
num_atom_types
(
self
):
def
num_atom_types
(
self
):
r
"""Number of atom types."""
return
9
return
9
@
property
@
property
def
num_bond_types
(
self
):
def
num_bond_types
(
self
):
r
"""Number of bond types."""
return
3
return
3
def
_md5sum
(
self
,
path
):
def
_md5sum
(
self
,
path
):
hash_md5
=
hashlib
.
md5
()
hash_md5
=
hashlib
.
md5
()
with
open
(
path
,
"rb"
)
as
f
:
with
open
(
path
,
"rb"
)
as
f
ile
:
buffer
=
f
.
read
()
buffer
=
f
ile
.
read
()
hash_md5
.
update
(
buffer
)
hash_md5
.
update
(
buffer
)
return
hash_md5
.
hexdigest
()
return
hash_md5
.
hexdigest
()
def
download
(
self
):
def
download
(
self
):
path
=
download
(
self
.
url
,
path
=
self
.
raw_data_path
)
path
=
download
(
self
.
url
,
path
=
self
.
raw_data_path
)
# Save to disk the MD5 hash of the downloaded file.
# Save to disk the MD5 hash of the downloaded file.
hash
=
self
.
_md5sum
(
path
)
hash
_data
=
self
.
_md5sum
(
path
)
if
hash
!=
self
.
md5sum_data
:
if
hash
_data
!=
self
.
md5sum_data
:
raise
ValueError
(
"Unexpected MD5 hash of the downloaded file"
)
raise
ValueError
(
"Unexpected MD5 hash of the downloaded file"
)
open
(
os
.
path
.
join
(
self
.
raw_path
,
hash
),
"w"
).
close
()
open
(
os
.
path
.
join
(
self
.
raw_path
,
hash
_data
),
"w"
).
close
()
# Download train/val/test splits.
# Download train/val/test splits.
path_split
=
download
(
path_split
=
download
(
self
.
url_stratified_split
,
path
=
self
.
split_data_path
self
.
url_stratified_split
,
path
=
self
.
split_data_path
...
@@ -201,8 +210,7 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -201,8 +210,7 @@ class PeptidesStructuralDataset(DGLDataset):
)
)
if
self
.
verbose
:
if
self
.
verbose
:
print
(
"Converting SMILES strings into graphs..."
)
print
(
"Converting SMILES strings into graphs..."
)
self
.
graphs
=
[]
self
.
labels
=
[]
for
i
in
tqdm
(
range
(
len
(
smiles_list
))):
for
i
in
tqdm
(
range
(
len
(
smiles_list
))):
smiles
=
smiles_list
[
i
]
smiles
=
smiles_list
[
i
]
y
=
data_df
.
iloc
[
i
][
target_names
]
y
=
data_df
.
iloc
[
i
][
target_names
]
...
@@ -244,8 +252,8 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -244,8 +252,8 @@ class PeptidesStructuralDataset(DGLDataset):
Returns:
Returns:
Dict with 'train', 'val', 'test', splits indices.
Dict with 'train', 'val', 'test', splits indices.
"""
"""
with
open
(
self
.
split_data_path
,
"rb"
)
as
f
:
with
open
(
self
.
split_data_path
,
"rb"
)
as
f
ile
:
split_dict
=
pickle
.
load
(
f
)
split_dict
=
pickle
.
load
(
f
ile
)
for
key
in
split_dict
.
keys
():
for
key
in
split_dict
.
keys
():
split_dict
[
key
]
=
F
.
zerocopy_from_numpy
(
split_dict
[
key
])
split_dict
[
key
]
=
F
.
zerocopy_from_numpy
(
split_dict
[
key
])
return
split_dict
return
split_dict
...
@@ -259,7 +267,8 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -259,7 +267,8 @@ class PeptidesStructuralDataset(DGLDataset):
Parameters
Parameters
----------
----------
idx : int or tensor
idx : int or tensor
The sample index, if idx is tensor will ignore transform.
The sample index.
1-D tensor as `idx` is allowed when transform is None.
Returns
Returns
-------
-------
...
@@ -270,20 +279,25 @@ class PeptidesStructuralDataset(DGLDataset):
...
@@ -270,20 +279,25 @@ class PeptidesStructuralDataset(DGLDataset):
Subset of the dataset at specified indices
Subset of the dataset at specified indices
"""
"""
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
self
.
_transform
is
None
:
return
Subset
(
self
,
idx
.
cpu
())
return
Subset
(
self
,
idx
.
cpu
())
raise
ValueError
(
"Tensor idx not supported when transform is not None."
)
if
self
.
_transform
is
None
:
if
self
.
_transform
is
None
:
return
self
.
graphs
[
idx
],
self
.
labels
[
idx
]
return
self
.
graphs
[
idx
],
self
.
labels
[
idx
]
else
:
return
self
.
_transform
(
self
.
graphs
[
idx
]),
self
.
labels
[
idx
]
return
self
.
_transform
(
self
.
graphs
[
idx
]),
self
.
labels
[
idx
]
class
PeptidesFunctionalDataset
(
DGLDataset
):
class
PeptidesFunctionalDataset
(
DGLDataset
):
r
"""Peptides functional dataset for the graph classification task.
r
"""Peptides functional dataset for the graph classification task.
DGL dataset of
15,535 peptides represented as their molecular graph
DGL dataset of
Peptides-func in the LRGB benchmark which contains
(SMILES) with 10-way multi-task binary classification of their
15,535 peptides represented as their molecular graph (SMILES) with
functional classes.
10-way multi-task binary classification of their
functional classes.
The 10 classes represent the following functional classes (in order):
The 10 classes represent the following functional classes (in order):
['antifungal', 'cell_cell_communication', 'anticancer',
['antifungal', 'cell_cell_communication', 'anticancer',
...
@@ -337,7 +351,8 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -337,7 +351,8 @@ class PeptidesFunctionalDataset(DGLDataset):
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> # support tensor to be index when transform is None
>>> # see details in __getitem__ function
>>> # get train dataset
>>> # get train dataset
>>> split_dict = dataset.get_idx_split()
>>> split_dict = dataset.get_idx_split()
>>> trainset = dataset[split_dict["train"]]
>>> trainset = dataset[split_dict["train"]]
...
@@ -364,7 +379,7 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -364,7 +379,7 @@ class PeptidesFunctionalDataset(DGLDataset):
force_reload
=
None
,
force_reload
=
None
,
verbose
=
None
,
verbose
=
None
,
transform
=
None
,
transform
=
None
,
smiles2graph
=
smiles2graph
,
smiles2graph
=
smiles2graph
_OGB
,
):
):
self
.
smiles2graph
=
smiles2graph
self
.
smiles2graph
=
smiles2graph
# MD5 hash of the dataset file.
# MD5 hash of the dataset file.
...
@@ -373,8 +388,10 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -373,8 +388,10 @@ class PeptidesFunctionalDataset(DGLDataset):
https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1
https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1
"""
"""
self
.
md5sum_stratified_split
=
"5a0114bdadc80b94fc7ae974f13ef061"
self
.
md5sum_stratified_split
=
"5a0114bdadc80b94fc7ae974f13ef061"
self
.
graphs
=
[]
self
.
labels
=
[]
super
(
PeptidesFunctionalDataset
,
self
).
__init__
(
super
().
__init__
(
name
=
"Peptides-func"
,
name
=
"Peptides-func"
,
raw_dir
=
raw_dir
,
raw_dir
=
raw_dir
,
url
=
"""
url
=
"""
...
@@ -387,44 +404,50 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -387,44 +404,50 @@ class PeptidesFunctionalDataset(DGLDataset):
@
property
@
property
def
raw_data_path
(
self
):
def
raw_data_path
(
self
):
r
"""Path to save the raw dataset file."""
return
os
.
path
.
join
(
self
.
raw_path
,
"peptide_multi_class_dataset.csv.gz"
)
return
os
.
path
.
join
(
self
.
raw_path
,
"peptide_multi_class_dataset.csv.gz"
)
@
property
@
property
def
split_data_path
(
self
):
def
split_data_path
(
self
):
r
"""Path to save the dataset split file."""
return
os
.
path
.
join
(
return
os
.
path
.
join
(
self
.
raw_path
,
"splits_random_stratified_peptide.pickle"
self
.
raw_path
,
"splits_random_stratified_peptide.pickle"
)
)
@
property
@
property
def
graph_path
(
self
):
def
graph_path
(
self
):
r
"""Path to save the processed dataset file."""
return
os
.
path
.
join
(
self
.
save_path
,
"Peptides-func.bin"
)
return
os
.
path
.
join
(
self
.
save_path
,
"Peptides-func.bin"
)
@
property
@
property
def
num_atom_types
(
self
):
def
num_atom_types
(
self
):
r
"""Number of atom types."""
return
9
return
9
@
property
@
property
def
num_bond_types
(
self
):
def
num_bond_types
(
self
):
r
"""Number of bond types."""
return
3
return
3
@
property
@
property
def
num_classes
(
self
):
def
num_classes
(
self
):
r
"""Number of graph classes."""
return
10
return
10
def
_md5sum
(
self
,
path
):
def
_md5sum
(
self
,
path
):
hash_md5
=
hashlib
.
md5
()
hash_md5
=
hashlib
.
md5
()
with
open
(
path
,
"rb"
)
as
f
:
with
open
(
path
,
"rb"
)
as
f
ile
:
buffer
=
f
.
read
()
buffer
=
f
ile
.
read
()
hash_md5
.
update
(
buffer
)
hash_md5
.
update
(
buffer
)
return
hash_md5
.
hexdigest
()
return
hash_md5
.
hexdigest
()
def
download
(
self
):
def
download
(
self
):
path
=
download
(
self
.
url
,
path
=
self
.
raw_data_path
)
path
=
download
(
self
.
url
,
path
=
self
.
raw_data_path
)
# Save to disk the MD5 hash of the downloaded file.
# Save to disk the MD5 hash of the downloaded file.
hash
=
self
.
_md5sum
(
path
)
hash
_data
=
self
.
_md5sum
(
path
)
if
hash
!=
self
.
md5sum_data
:
if
hash
_data
!=
self
.
md5sum_data
:
raise
ValueError
(
"Unexpected MD5 hash of the downloaded file"
)
raise
ValueError
(
"Unexpected MD5 hash of the downloaded file"
)
open
(
os
.
path
.
join
(
self
.
raw_path
,
hash
),
"w"
).
close
()
open
(
os
.
path
.
join
(
self
.
raw_path
,
hash
_data
),
"w"
).
close
()
# Download train/val/test splits.
# Download train/val/test splits.
path_split
=
download
(
path_split
=
download
(
self
.
url_stratified_split
,
path
=
self
.
split_data_path
self
.
url_stratified_split
,
path
=
self
.
split_data_path
...
@@ -438,8 +461,7 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -438,8 +461,7 @@ class PeptidesFunctionalDataset(DGLDataset):
smiles_list
=
data_df
[
"smiles"
]
smiles_list
=
data_df
[
"smiles"
]
if
self
.
verbose
:
if
self
.
verbose
:
print
(
"Converting SMILES strings into graphs..."
)
print
(
"Converting SMILES strings into graphs..."
)
self
.
graphs
=
[]
self
.
labels
=
[]
for
i
in
tqdm
(
range
(
len
(
smiles_list
))):
for
i
in
tqdm
(
range
(
len
(
smiles_list
))):
smiles
=
smiles_list
[
i
]
smiles
=
smiles_list
[
i
]
graph
=
self
.
smiles2graph
(
smiles
)
graph
=
self
.
smiles2graph
(
smiles
)
...
@@ -478,8 +500,8 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -478,8 +500,8 @@ class PeptidesFunctionalDataset(DGLDataset):
Returns:
Returns:
Dict with 'train', 'val', 'test', splits indices.
Dict with 'train', 'val', 'test', splits indices.
"""
"""
with
open
(
self
.
split_data_path
,
"rb"
)
as
f
:
with
open
(
self
.
split_data_path
,
"rb"
)
as
f
ile
:
split_dict
=
pickle
.
load
(
f
)
split_dict
=
pickle
.
load
(
f
ile
)
for
key
in
split_dict
.
keys
():
for
key
in
split_dict
.
keys
():
split_dict
[
key
]
=
F
.
zerocopy_from_numpy
(
split_dict
[
key
])
split_dict
[
key
]
=
F
.
zerocopy_from_numpy
(
split_dict
[
key
])
return
split_dict
return
split_dict
...
@@ -493,7 +515,8 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -493,7 +515,8 @@ class PeptidesFunctionalDataset(DGLDataset):
Parameters
Parameters
----------
----------
idx : int or tensor
idx : int or tensor
The sample index, if idx is tensor will ignore transform.
The sample index.
1-D tensor as `idx` is allowed when transform is None.
Returns
Returns
-------
-------
...
@@ -504,19 +527,24 @@ class PeptidesFunctionalDataset(DGLDataset):
...
@@ -504,19 +527,24 @@ class PeptidesFunctionalDataset(DGLDataset):
Subset of the dataset at specified indices
Subset of the dataset at specified indices
"""
"""
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
self
.
_transform
is
None
:
return
Subset
(
self
,
idx
.
cpu
())
return
Subset
(
self
,
idx
.
cpu
())
raise
ValueError
(
"Tensor idx not supported when transform is not None."
)
if
self
.
_transform
is
None
:
if
self
.
_transform
is
None
:
return
self
.
graphs
[
idx
],
self
.
labels
[
idx
]
return
self
.
graphs
[
idx
],
self
.
labels
[
idx
]
else
:
return
self
.
_transform
(
self
.
graphs
[
idx
]),
self
.
labels
[
idx
]
return
self
.
_transform
(
self
.
graphs
[
idx
]),
self
.
labels
[
idx
]
class
VOCSuperpixelsDataset
(
DGLDataset
):
class
VOCSuperpixelsDataset
(
DGLDataset
):
r
"""VOCSuperpixels dataset for the node classification task.
r
"""VOCSuperpixels dataset for the node classification task.
DGL dataset of Pascal VOC Superpixels which contains image superpixels
DGL dataset of PascalVOC-SP in the LRGB benchmark which contains image
and a semantic segmentation label for each node superpixel.
superpixels and a semantic segmentation label for each node superpixel.
color map
color map
0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle,
0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle,
...
@@ -545,14 +573,16 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -545,14 +573,16 @@ class VOCSuperpixelsDataset(DGLDataset):
construct_format : str, optional
construct_format : str, optional
Option to select the graph construction format.
Option to select the graph construction format.
Should be chosen from the following formats:
Should be chosen from the following formats:
"edge_wt_only_coord": the graphs are 8-nn graphs with the edge weights
- "edge_wt_only_coord": the graphs are 8-nn graphs with the edge weights
computed based on only spatial coordinates of superpixel nodes.
computed based on only spatial coordinates of superpixel nodes.
"edge_wt_coord_feat": the graphs are 8-nn graphs with the edge weights
-
"edge_wt_coord_feat": the graphs are 8-nn graphs with the edge weights
computed based on combination of spatial coordinates and feature
computed based on combination of spatial coordinates and feature
values of superpixel nodes.
values of superpixel nodes.
"edge_wt_region_boundary": the graphs are region boundary graphs where two
- "edge_wt_region_boundary": the graphs are region boundary graphs where two
regions (i.e. superpixel nodes) have an edge between them if they share
regions (i.e. superpixel nodes) have an edge between them if they
a boundary in the original image.
share a boundary in the original image.
Default: "edge_wt_region_boundary".
Default: "edge_wt_region_boundary".
slic_compactness : int, optional
slic_compactness : int, optional
Option to select compactness of slic that was used for superpixels
Option to select compactness of slic that was used for superpixels
...
@@ -581,16 +611,19 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -581,16 +611,19 @@ class VOCSuperpixelsDataset(DGLDataset):
>>> graph = train_dataset[0]
>>> graph = train_dataset[0]
>>> graph
>>> graph
Graph(num_nodes=460, num_edges=2632,
Graph(num_nodes=460, num_edges=2632,
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),
'label': Scheme(shape=(), dtype=torch.int32)}
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> # support tensor to be index when transform is None
>>> # see details in __getitem__ function
>>> import torch
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> idx = torch.tensor([0, 1, 2])
>>> train_dataset_subset = train_dataset[idx]
>>> train_dataset_subset = train_dataset[idx]
>>> train_dataset_subset[0]
>>> train_dataset_subset[0]
Graph(num_nodes=460, num_edges=2632,
Graph(num_nodes=460, num_edges=2632,
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),
'label': Scheme(shape=(), dtype=torch.int32)}
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
"""
"""
...
@@ -629,8 +662,6 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -629,8 +662,6 @@ class VOCSuperpixelsDataset(DGLDataset):
verbose
=
None
,
verbose
=
None
,
transform
=
None
,
transform
=
None
,
):
):
self
.
construct_format
=
construct_format
self
.
slic_compactness
=
slic_compactness
assert
split
in
[
"train"
,
"val"
,
"test"
],
"split not valid."
assert
split
in
[
"train"
,
"val"
,
"test"
],
"split not valid."
assert
construct_format
in
[
assert
construct_format
in
[
"edge_wt_only_coord"
,
"edge_wt_only_coord"
,
...
@@ -638,8 +669,13 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -638,8 +669,13 @@ class VOCSuperpixelsDataset(DGLDataset):
"edge_wt_region_boundary"
,
"edge_wt_region_boundary"
,
],
"construct_format not valid."
],
"construct_format not valid."
assert
slic_compactness
in
[
10
,
30
],
"slic_compactness not valid."
assert
slic_compactness
in
[
10
,
30
],
"slic_compactness not valid."
self
.
construct_format
=
construct_format
self
.
slic_compactness
=
slic_compactness
self
.
split
=
split
self
.
split
=
split
super
(
VOCSuperpixelsDataset
,
self
).
__init__
(
self
.
graphs
=
[]
super
().
__init__
(
name
=
"PascalVOC-SP"
,
name
=
"PascalVOC-SP"
,
raw_dir
=
raw_dir
,
raw_dir
=
raw_dir
,
url
=
self
.
urls
[
self
.
slic_compactness
][
self
.
construct_format
],
url
=
self
.
urls
[
self
.
slic_compactness
][
self
.
construct_format
],
...
@@ -650,6 +686,7 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -650,6 +686,7 @@ class VOCSuperpixelsDataset(DGLDataset):
@
property
@
property
def
save_path
(
self
):
def
save_path
(
self
):
r
"""Directory to save the processed dataset."""
return
os
.
path
.
join
(
return
os
.
path
.
join
(
self
.
raw_path
,
self
.
raw_path
,
"slic_compactness_"
+
str
(
self
.
slic_compactness
),
"slic_compactness_"
+
str
(
self
.
slic_compactness
),
...
@@ -658,10 +695,12 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -658,10 +695,12 @@ class VOCSuperpixelsDataset(DGLDataset):
@
property
@
property
def
raw_data_path
(
self
):
def
raw_data_path
(
self
):
r
"""Path to save the raw dataset file."""
return
os
.
path
.
join
(
self
.
save_path
,
f
"
{
self
.
split
}
.pickle"
)
return
os
.
path
.
join
(
self
.
save_path
,
f
"
{
self
.
split
}
.pickle"
)
@
property
@
property
def
graph_path
(
self
):
def
graph_path
(
self
):
r
"""Path to save the processed dataset file."""
return
os
.
path
.
join
(
self
.
save_path
,
f
"processed_
{
self
.
split
}
.pkl"
)
return
os
.
path
.
join
(
self
.
save_path
,
f
"processed_
{
self
.
split
}
.pkl"
)
@
property
@
property
...
@@ -689,10 +728,9 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -689,10 +728,9 @@ class VOCSuperpixelsDataset(DGLDataset):
os
.
unlink
(
path
)
os
.
unlink
(
path
)
def
process
(
self
):
def
process
(
self
):
with
open
(
self
.
raw_data_path
,
"rb"
)
as
f
:
with
open
(
self
.
raw_data_path
,
"rb"
)
as
f
ile
:
graphs
=
pickle
.
load
(
f
)
graphs
=
pickle
.
load
(
f
ile
)
self
.
graphs
=
[]
for
idx
in
tqdm
(
for
idx
in
tqdm
(
range
(
len
(
graphs
)),
desc
=
f
"Processing
{
self
.
split
}
dataset"
range
(
len
(
graphs
)),
desc
=
f
"Processing
{
self
.
split
}
dataset"
):
):
...
@@ -715,13 +753,13 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -715,13 +753,13 @@ class VOCSuperpixelsDataset(DGLDataset):
self
.
graphs
.
append
(
DGLgraph
)
self
.
graphs
.
append
(
DGLgraph
)
def
load
(
self
):
def
load
(
self
):
with
open
(
self
.
graph_path
,
"rb"
)
as
f
:
with
open
(
self
.
graph_path
,
"rb"
)
as
f
ile
:
f
=
pickle
.
load
(
f
)
graphs
=
pickle
.
load
(
f
ile
)
self
.
graphs
=
f
self
.
graphs
=
graphs
def
save
(
self
):
def
save
(
self
):
with
open
(
os
.
path
.
join
(
self
.
graph_path
),
"wb"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
graph_path
),
"wb"
)
as
f
ile
:
pickle
.
dump
(
self
.
graphs
,
f
)
pickle
.
dump
(
self
.
graphs
,
f
ile
)
def
has_cache
(
self
):
def
has_cache
(
self
):
return
os
.
path
.
exists
(
self
.
graph_path
)
return
os
.
path
.
exists
(
self
.
graph_path
)
...
@@ -732,7 +770,8 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -732,7 +770,8 @@ class VOCSuperpixelsDataset(DGLDataset):
Parameters
Parameters
----------
----------
idx : int or tensor
idx : int or tensor
The sample index, if idx is tensor will ignore transform.
The sample index.
1-D tensor as `idx` is allowed when transform is None.
Returns
Returns
-------
-------
...
@@ -747,9 +786,14 @@ class VOCSuperpixelsDataset(DGLDataset):
...
@@ -747,9 +786,14 @@ class VOCSuperpixelsDataset(DGLDataset):
Subset of the dataset at specified indices
Subset of the dataset at specified indices
"""
"""
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
F
.
is_tensor
(
idx
)
and
idx
.
dim
()
==
1
:
if
self
.
_transform
is
None
:
return
Subset
(
self
,
idx
.
cpu
())
return
Subset
(
self
,
idx
.
cpu
())
raise
ValueError
(
"Tensor idx not supported when transform is not None."
)
if
self
.
_transform
is
None
:
if
self
.
_transform
is
None
:
return
self
.
graphs
[
idx
]
return
self
.
graphs
[
idx
]
else
:
return
self
.
_transform
(
self
.
graphs
[
idx
])
return
self
.
_transform
(
self
.
graphs
[
idx
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment