"...api/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "3f240fbb3734ab5f112a3d26d3856cf0a0e1a092"
Unverified Commit ad4df9c5 authored by paoxiaode, committed by GitHub

[Dataset] update docstring of LRGB (#6430)

parent ec5c515c
@@ -3,7 +3,7 @@ import os
 import pickle
 
 import pandas as pd
-from ogb.utils import smiles2graph
+from ogb.utils import smiles2graph as smiles2graph_OGB
 from tqdm import tqdm
 
 from .. import backend as F
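
Note on the rename above: importing the OGB converter as smiles2graph_OGB stops the module-level import from colliding with the smiles2graph constructor keyword. Any callable with the same contract can be passed in its place; below is a minimal sketch, assuming the dataset class is exported from dgl.data (the cached wrapper is illustrative, not part of this commit):

from functools import lru_cache

from dgl.data import PeptidesStructuralDataset
from ogb.utils import smiles2graph as smiles2graph_OGB

@lru_cache(maxsize=None)
def cached_smiles2graph(smiles):
    # Delegate to the OGB converter; keep its dict layout
    # ('edge_index', 'edge_feat', 'node_feat', 'num_nodes').
    return smiles2graph_OGB(smiles)

dataset = PeptidesStructuralDataset(smiles2graph=cached_smiles2graph)
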
@@ -23,22 +23,23 @@ from .utils import (
 class PeptidesStructuralDataset(DGLDataset):
     r"""Peptides structure dataset for the graph regression task.
 
-    DGL dataset of 15,535 small peptides represented as their molecular
-    graph (SMILES) with 11 regression targets derived from the peptide's
-    3D structure.
+    DGL dataset of Peptides-struct in the LRGB benchmark which contains
+    15,535 small peptides represented as their molecular graph (SMILES)
+    with 11 regression targets derived from the peptide's 3D structure.
 
     The 11 regression targets were precomputed from molecules' 3D structure:
-    Inertia_mass_[a-c]: The principal component of the inertia of the
-        mass, with some normalizations. (Sorted)
-    Inertia_valence_[a-c]: The principal component of the inertia of the
-        Hydrogen atoms. This is basically a measure of the 3D
-        distribution of hydrogens. (Sorted)
-    length_[a-c]: The length around the 3 main geometric axis of
-        the 3D objects (without considering atom types). (Sorted)
-    Spherocity: SpherocityIndex descriptor computed by
-        rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex
-    Plane_best_fit: Plane of best fit (PBF) descriptor computed by
-        rdkit.Chem.rdMolDescriptors.CalcPBF
+
+    - Inertia_mass_[a-c]: The principal component of the inertia of the
+      mass, with some normalizations. (Sorted)
+    - Inertia_valence_[a-c]: The principal component of the inertia of the
+      Hydrogen atoms. This is basically a measure of the 3D
+      distribution of hydrogens. (Sorted)
+    - length_[a-c]: The length around the 3 main geometric axis of
+      the 3D objects (without considering atom types). (Sorted)
+    - Spherocity: SpherocityIndex descriptor computed by
+      rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex
+    - Plane_best_fit: Plane of best fit (PBF) descriptor computed by
+      rdkit.Chem.rdMolDescriptors.CalcPBF
 
     Reference `<https://arxiv.org/abs/2206.08164.pdf>`_
@@ -87,7 +88,8 @@ class PeptidesStructuralDataset(DGLDataset):
     edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
 
-    >>> # accept tensor to be index, but will ignore transform parameter
+    >>> # support tensor to be index when transform is None
+    >>> # see details in __getitem__ function
     >>> # get train dataset
     >>> split_dict = dataset.get_idx_split()
     >>> trainset = dataset[split_dict["train"]]
@@ -114,7 +116,7 @@ class PeptidesStructuralDataset(DGLDataset):
         force_reload=None,
         verbose=None,
         transform=None,
-        smiles2graph=smiles2graph,
+        smiles2graph=smiles2graph_OGB,
     ):
         self.smiles2graph = smiles2graph
         # MD5 hash of the dataset file.
@@ -123,8 +125,10 @@ class PeptidesStructuralDataset(DGLDataset):
         https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1
         """
         self.md5sum_stratified_split = "5a0114bdadc80b94fc7ae974f13ef061"
+        self.graphs = []
+        self.labels = []
 
-        super(PeptidesStructuralDataset, self).__init__(
+        super().__init__(
             name="Peptides-struc",
             raw_dir=raw_dir,
             url="""
@@ -137,40 +141,45 @@ class PeptidesStructuralDataset(DGLDataset):
     @property
     def raw_data_path(self):
+        r"""Path to save the raw dataset file."""
         return os.path.join(self.raw_path, "peptide_structure_dataset.csv.gz")
 
     @property
     def split_data_path(self):
+        r"""Path to save the dataset split file."""
         return os.path.join(
             self.raw_path, "splits_random_stratified_peptide_structure.pickle"
         )
 
     @property
     def graph_path(self):
+        r"""Path to save the processed dataset file."""
         return os.path.join(self.save_path, "Peptides-struc.bin")
 
     @property
     def num_atom_types(self):
+        r"""Number of atom types."""
         return 9
 
     @property
     def num_bond_types(self):
+        r"""Number of bond types."""
         return 3
 
     def _md5sum(self, path):
         hash_md5 = hashlib.md5()
-        with open(path, "rb") as f:
-            buffer = f.read()
+        with open(path, "rb") as file:
+            buffer = file.read()
         hash_md5.update(buffer)
         return hash_md5.hexdigest()
 
     def download(self):
         path = download(self.url, path=self.raw_data_path)
         # Save to disk the MD5 hash of the downloaded file.
-        hash = self._md5sum(path)
-        if hash != self.md5sum_data:
+        hash_data = self._md5sum(path)
+        if hash_data != self.md5sum_data:
             raise ValueError("Unexpected MD5 hash of the downloaded file")
-        open(os.path.join(self.raw_path, hash), "w").close()
+        open(os.path.join(self.raw_path, hash_data), "w").close()
         # Download train/val/test splits.
         path_split = download(
             self.url_stratified_split, path=self.split_data_path
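
The download step pins the raw CSV to a known MD5 and drops an empty marker file named after the hash into raw_path. The in-tree _md5sum reads the whole file into memory; for larger artifacts the same check can be done in chunks. A sketch (function name and chunk size are illustrative, not from this commit):

import hashlib

def md5sum_chunked(path, chunk_size=1 << 20):
    # Hash the file 1 MiB at a time instead of reading it whole.
    hash_md5 = hashlib.md5()
    with open(path, "rb") as file:
        for chunk in iter(lambda: file.read(chunk_size), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
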
@@ -201,8 +210,7 @@ class PeptidesStructuralDataset(DGLDataset):
         )
         if self.verbose:
             print("Converting SMILES strings into graphs...")
-        self.graphs = []
-        self.labels = []
+
         for i in tqdm(range(len(smiles_list))):
             smiles = smiles_list[i]
             y = data_df.iloc[i][target_names]
@@ -244,8 +252,8 @@ class PeptidesStructuralDataset(DGLDataset):
         Returns:
             Dict with 'train', 'val', 'test', splits indices.
         """
-        with open(self.split_data_path, "rb") as f:
-            split_dict = pickle.load(f)
+        with open(self.split_data_path, "rb") as file:
+            split_dict = pickle.load(file)
         for key in split_dict.keys():
             split_dict[key] = F.zerocopy_from_numpy(split_dict[key])
         return split_dict
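
Because F.zerocopy_from_numpy wraps the pickled NumPy arrays as backend tensors without a copy, the returned indices can be fed straight back into the dataset. A usage sketch, assuming the data has already been downloaded:

dataset = PeptidesStructuralDataset()
split_dict = dataset.get_idx_split()
trainset = dataset[split_dict["train"]]  # a Subset, see __getitem__ below
valset = dataset[split_dict["val"]]
testset = dataset[split_dict["test"]]
print(len(trainset), len(valset), len(testset))
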
@@ -259,7 +267,8 @@ class PeptidesStructuralDataset(DGLDataset):
         Parameters
         ---------
         idx : int or tensor
-            The sample index, if idx is tensor will ignore transform.
+            The sample index.
+            1-D tensor as `idx` is allowed when transform is None.
 
         Returns
         -------
@@ -270,20 +279,25 @@ class PeptidesStructuralDataset(DGLDataset):
             Subset of the dataset at specified indices
         """
         if F.is_tensor(idx) and idx.dim() == 1:
-            return Subset(self, idx.cpu())
+            if self._transform is None:
+                return Subset(self, idx.cpu())
+            raise ValueError(
+                "Tensor idx not supported when transform is not None."
+            )
 
         if self._transform is None:
             return self.graphs[idx], self.labels[idx]
-        else:
-            return self._transform(self.graphs[idx]), self.labels[idx]
+        return self._transform(self.graphs[idx]), self.labels[idx]
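
The behavior change in short: tensor indexing used to silently skip the transform, and it now returns a Subset only when no transform is set, raising otherwise. A sketch, assuming a local copy of the dataset (my_transform is a hypothetical callable):

import torch

plain = PeptidesStructuralDataset()        # transform=None
subset = plain[torch.tensor([0, 1, 2])]    # OK: a Subset of three samples
graph, label = subset[0]

# aug = PeptidesStructuralDataset(transform=my_transform)
# aug[torch.tensor([0, 1, 2])]             # now raises ValueError
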
 class PeptidesFunctionalDataset(DGLDataset):
     r"""Peptides functional dataset for the graph classification task.
 
-    DGL dataset of 15,535 peptides represented as their molecular graph
-    (SMILES) with 10-way multi-task binary classification of their
-    functional classes.
+    DGL dataset of Peptides-func in the LRGB benchmark which contains
+    15,535 peptides represented as their molecular graph (SMILES) with
+    10-way multi-task binary classification of their functional classes.
 
     The 10 classes represent the following functional classes (in order):
 
     ['antifungal', 'cell_cell_communication', 'anticancer',
@@ -337,7 +351,8 @@ class PeptidesFunctionalDataset(DGLDataset):
     edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
 
-    >>> # accept tensor to be index, but will ignore transform parameter
+    >>> # support tensor to be index when transform is None
+    >>> # see details in __getitem__ function
     >>> # get train dataset
     >>> split_dict = dataset.get_idx_split()
     >>> trainset = dataset[split_dict["train"]]
@@ -364,7 +379,7 @@ class PeptidesFunctionalDataset(DGLDataset):
         force_reload=None,
         verbose=None,
         transform=None,
-        smiles2graph=smiles2graph,
+        smiles2graph=smiles2graph_OGB,
     ):
         self.smiles2graph = smiles2graph
         # MD5 hash of the dataset file.
@@ -373,8 +388,10 @@ class PeptidesFunctionalDataset(DGLDataset):
         https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1
         """
         self.md5sum_stratified_split = "5a0114bdadc80b94fc7ae974f13ef061"
+        self.graphs = []
+        self.labels = []
 
-        super(PeptidesFunctionalDataset, self).__init__(
+        super().__init__(
             name="Peptides-func",
             raw_dir=raw_dir,
             url="""
@@ -387,44 +404,50 @@ class PeptidesFunctionalDataset(DGLDataset):
     @property
     def raw_data_path(self):
+        r"""Path to save the raw dataset file."""
         return os.path.join(self.raw_path, "peptide_multi_class_dataset.csv.gz")
 
     @property
     def split_data_path(self):
+        r"""Path to save the dataset split file."""
         return os.path.join(
             self.raw_path, "splits_random_stratified_peptide.pickle"
         )
 
     @property
     def graph_path(self):
+        r"""Path to save the processed dataset file."""
         return os.path.join(self.save_path, "Peptides-func.bin")
 
     @property
     def num_atom_types(self):
+        r"""Number of atom types."""
         return 9
 
     @property
     def num_bond_types(self):
+        r"""Number of bond types."""
         return 3
 
     @property
     def num_classes(self):
+        r"""Number of graph classes."""
         return 10
 
     def _md5sum(self, path):
         hash_md5 = hashlib.md5()
-        with open(path, "rb") as f:
-            buffer = f.read()
+        with open(path, "rb") as file:
+            buffer = file.read()
         hash_md5.update(buffer)
         return hash_md5.hexdigest()
 
     def download(self):
         path = download(self.url, path=self.raw_data_path)
         # Save to disk the MD5 hash of the downloaded file.
-        hash = self._md5sum(path)
-        if hash != self.md5sum_data:
+        hash_data = self._md5sum(path)
+        if hash_data != self.md5sum_data:
             raise ValueError("Unexpected MD5 hash of the downloaded file")
-        open(os.path.join(self.raw_path, hash), "w").close()
+        open(os.path.join(self.raw_path, hash_data), "w").close()
         # Download train/val/test splits.
         path_split = download(
             self.url_stratified_split, path=self.split_data_path
@@ -438,8 +461,7 @@ class PeptidesFunctionalDataset(DGLDataset):
         smiles_list = data_df["smiles"]
         if self.verbose:
             print("Converting SMILES strings into graphs...")
-        self.graphs = []
-        self.labels = []
+
         for i in tqdm(range(len(smiles_list))):
             smiles = smiles_list[i]
             graph = self.smiles2graph(smiles)
@@ -478,8 +500,8 @@ class PeptidesFunctionalDataset(DGLDataset):
         Returns:
             Dict with 'train', 'val', 'test', splits indices.
         """
-        with open(self.split_data_path, "rb") as f:
-            split_dict = pickle.load(f)
+        with open(self.split_data_path, "rb") as file:
+            split_dict = pickle.load(file)
         for key in split_dict.keys():
             split_dict[key] = F.zerocopy_from_numpy(split_dict[key])
         return split_dict
@@ -493,7 +515,8 @@ class PeptidesFunctionalDataset(DGLDataset):
         Parameters
         ---------
         idx : int or tensor
-            The sample index, if idx is tensor will ignore transform.
+            The sample index.
+            1-D tensor as `idx` is allowed when transform is None.
 
         Returns
         -------
@@ -504,19 +527,24 @@ class PeptidesFunctionalDataset(DGLDataset):
             Subset of the dataset at specified indices
         """
         if F.is_tensor(idx) and idx.dim() == 1:
-            return Subset(self, idx.cpu())
+            if self._transform is None:
+                return Subset(self, idx.cpu())
+            raise ValueError(
+                "Tensor idx not supported when transform is not None."
+            )
 
         if self._transform is None:
             return self.graphs[idx], self.labels[idx]
-        else:
-            return self._transform(self.graphs[idx]), self.labels[idx]
+        return self._transform(self.graphs[idx]), self.labels[idx]
 
 
 class VOCSuperpixelsDataset(DGLDataset):
     r"""VOCSuperpixels dataset for the node classification task.
 
-    DGL dataset of Pascal VOC Superpixels which contains image superpixels
-    and a semantic segmentation label for each node superpixel.
+    DGL dataset of PascalVOC-SP in the LRGB benchmark which contains image
+    superpixels and a semantic segmentation label for each node superpixel.
 
     color map
     0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle,
@@ -545,14 +573,16 @@ class VOCSuperpixelsDataset(DGLDataset):
     construct_format : str, optional
         Option to select the graph construction format.
         Should be chosen from the following formats:
-        "edge_wt_only_coord": the graphs are 8-nn graphs with the edge weights
-            computed based on only spatial coordinates of superpixel nodes.
-        "edge_wt_coord_feat": the graphs are 8-nn graphs with the edge weights
-            computed based on combination of spatial coordinates and feature
-            values of superpixel nodes.
-        "edge_wt_region_boundary": the graphs region boundary graphs where two
-            regions (i.e. superpixel nodes) have an edge between them if they share
-            a boundary in the original image.
+
+        - "edge_wt_only_coord": the graphs are 8-nn graphs with the edge weights
+          computed based on only spatial coordinates of superpixel nodes.
+        - "edge_wt_coord_feat": the graphs are 8-nn graphs with the edge weights
+          computed based on combination of spatial coordinates and feature
+          values of superpixel nodes.
+        - "edge_wt_region_boundary": the graphs are region boundary graphs where
+          two regions (i.e. superpixel nodes) have an edge between them if they
+          share a boundary in the original image.
+
         Default: "edge_wt_region_boundary".
     slic_compactness : int, optional
         Option to select compactness of slic that was used for superpixels
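
Both options are validated by asserts in __init__ (see below), so invalid values fail fast at construction. A construction sketch, assuming the keyword names match the parameters documented above:

train_dataset = VOCSuperpixelsDataset(
    split="train",
    construct_format="edge_wt_coord_feat",  # or "edge_wt_only_coord",
                                            # "edge_wt_region_boundary"
    slic_compactness=30,                    # 10 or 30
)
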
@@ -581,16 +611,19 @@ class VOCSuperpixelsDataset(DGLDataset):
     >>> graph = train_dataset[0]
     >>> graph
     Graph(num_nodes=460, num_edges=2632,
-          ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
+          ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),
+                         'label': Scheme(shape=(), dtype=torch.int32)}
           edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
 
-    >>> # accept tensor to be index, but will ignore transform parameter
+    >>> # support tensor to be index when transform is None
+    >>> # see details in __getitem__ function
     >>> import torch
     >>> idx = torch.tensor([0, 1, 2])
     >>> train_dataset_subset = train_dataset[idx]
     >>> train_dataset_subset[0]
     Graph(num_nodes=460, num_edges=2632,
-          ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
+          ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),
+                         'label': Scheme(shape=(), dtype=torch.int32)}
           edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
     """
@@ -629,8 +662,6 @@ class VOCSuperpixelsDataset(DGLDataset):
         verbose=None,
         transform=None,
     ):
-        self.construct_format = construct_format
-        self.slic_compactness = slic_compactness
         assert split in ["train", "val", "test"], "split not valid."
         assert construct_format in [
             "edge_wt_only_coord",
@@ -638,8 +669,13 @@ class VOCSuperpixelsDataset(DGLDataset):
             "edge_wt_region_boundary",
         ], "construct_format not valid."
         assert slic_compactness in [10, 30], "slic_compactness not valid."
+        self.construct_format = construct_format
+        self.slic_compactness = slic_compactness
         self.split = split
+        self.graphs = []
 
-        super(VOCSuperpixelsDataset, self).__init__(
+        super().__init__(
             name="PascalVOC-SP",
             raw_dir=raw_dir,
             url=self.urls[self.slic_compactness][self.construct_format],
@@ -650,6 +686,7 @@ class VOCSuperpixelsDataset(DGLDataset):
     @property
     def save_path(self):
+        r"""Directory to save the processed dataset."""
         return os.path.join(
             self.raw_path,
             "slic_compactness_" + str(self.slic_compactness),
@@ -658,10 +695,12 @@ class VOCSuperpixelsDataset(DGLDataset):
     @property
     def raw_data_path(self):
+        r"""Path to save the raw dataset file."""
         return os.path.join(self.save_path, f"{self.split}.pickle")
 
     @property
     def graph_path(self):
+        r"""Path to save the processed dataset file."""
         return os.path.join(self.save_path, f"processed_{self.split}.pkl")
 
     @property
@@ -689,10 +728,9 @@ class VOCSuperpixelsDataset(DGLDataset):
             os.unlink(path)
 
     def process(self):
-        with open(self.raw_data_path, "rb") as f:
-            graphs = pickle.load(f)
-        self.graphs = []
+        with open(self.raw_data_path, "rb") as file:
+            graphs = pickle.load(file)
         for idx in tqdm(
             range(len(graphs)), desc=f"Processing {self.split} dataset"
         ):
@@ -715,13 +753,13 @@ class VOCSuperpixelsDataset(DGLDataset):
             self.graphs.append(DGLgraph)
 
     def load(self):
-        with open(self.graph_path, "rb") as f:
-            f = pickle.load(f)
-            self.graphs = f
+        with open(self.graph_path, "rb") as file:
+            graphs = pickle.load(file)
+            self.graphs = graphs
 
     def save(self):
-        with open(os.path.join(self.graph_path), "wb") as f:
-            pickle.dump(self.graphs, f)
+        with open(os.path.join(self.graph_path), "wb") as file:
+            pickle.dump(self.graphs, file)
 
     def has_cache(self):
         return os.path.exists(self.graph_path)
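
Together, save/has_cache/load implement the usual DGLDataset caching handshake: the first construction runs download + process + save, and later constructions see has_cache() return True and deserialize via load(). A sketch of the effect (assuming the default raw_dir):

first = VOCSuperpixelsDataset(split="val")   # downloads, processes, pickles to graph_path
again = VOCSuperpixelsDataset(split="val")   # cache hit: load() reads the pickle
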
@@ -732,7 +770,8 @@ class VOCSuperpixelsDataset(DGLDataset):
         Parameters
         ---------
         idx : int or tensor
-            The sample index, if idx is tensor will ignore transform.
+            The sample index.
+            1-D tensor as `idx` is allowed when transform is None.
 
         Returns
         -------
@@ -747,9 +786,14 @@ class VOCSuperpixelsDataset(DGLDataset):
             Subset of the dataset at specified indices
         """
         if F.is_tensor(idx) and idx.dim() == 1:
-            return Subset(self, idx.cpu())
+            if self._transform is None:
+                return Subset(self, idx.cpu())
+            raise ValueError(
+                "Tensor idx not supported when transform is not None."
+            )
 
         if self._transform is None:
             return self.graphs[idx]
-        else:
-            return self._transform(self.graphs[idx])
+        return self._transform(self.graphs[idx])
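
When a transform is set and a batch of samples is still needed, indexing with Python ints applies the transform per sample; a hedged workaround sketch (dataset and idx are stand-ins from the examples above):

# Tensor indexing would raise here, so iterate the indices instead.
graphs = [dataset[int(i)] for i in idx]
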