Unverified commit d22f96fe, authored by Xiangkun Hu, committed by GitHub
Browse files

[Dataset] QM7bDataset (#1915)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* QM7bDataset

* Update qm7b.py
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent a4c931a9
......@@ -16,7 +16,7 @@ from .gindt import GINDataset
from .bitcoinotc import BitcoinOTC
from .gdelt import GDELT
from .icews18 import ICEWS18
from .qm7b import QM7b
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
......
"""QM7b dataset for graph property prediction (regression)."""
from scipy import io
import numpy as np
import os
from .utils import get_download_dir, download
from ..utils import retry_method_with_fix
from .. import convert
from .dgl_dataset import DGLDataset
from .utils import download, save_graphs, load_graphs, \
check_sha1, deprecate_property
from .. import backend as F
from ..convert import graph as dgl_graph
class QM7bDataset(DGLDataset):
r"""QM7b dataset for graph property prediction (regression)
class QM7b(object):
"""
This dataset consists of 7,211 molecules with 14 regression targets.
Nodes means atoms and edges means bonds. Edge data 'h' means
Nodes means atoms and edges means bonds. Edge data 'h' means
the entry of Coulomb matrix.
Reference: http://quantum-machine.org/datasets/
Statistics
----------
Number of graphs: 7,211
Number of regression targets: 14
Average number of nodes: 15
Average number of edges: 245
Edge feature size: 1
Parameters
----------
raw_dir : str
Directory where the raw input data is downloaded to and stored.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: False.
Reference:
- `QM7b Dataset <http://quantum-machine.org/datasets/>`_
Attributes
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM7bDataset()
>>> data.num_labels
14
>>>
>>> # iterate over the dataset
>>> for g, label in data:
... edge_feat = g.edata['h'] # get edge feature
... # your code here...
...
>>>
"""
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self):
self.dir = get_download_dir()
self.path = os.path.join(self.dir, 'qm7b', "qm7b.mat")
self.graphs = []
self._load(self.path)
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
    """Set up the dataset under the name ``'qm7b'``.

    The three arguments are forwarded unchanged to the base-class
    constructor, together with the class-level download URL ``_url``.
    """
    base_kwargs = dict(name='qm7b',
                       url=self._url,
                       raw_dir=raw_dir,
                       force_reload=force_reload,
                       verbose=verbose)
    super(QM7bDataset, self).__init__(**base_kwargs)
def _download(self):
download(self._url, path=self.path)
def process(self):
    """Verify the raw ``.mat`` file and build the graphs and labels from it.

    The file is looked up at ``self.raw_path + '.mat'``.  Its content hash is
    checked against ``self._sha1_str`` with ``check_sha1`` before parsing; on
    success the result of ``self._load_graph`` is stored in ``self.graphs``
    and ``self.label``.

    Raises
    ------
    UserWarning
        If the content hash of the raw file does not match the expected one.
    """
    mat_path = self.raw_path + '.mat'
    if not check_sha1(mat_path, self._sha1_str):
        # Every fragment ends with a space so that the implicitly
        # concatenated sentences do not run together in the final message.
        raise UserWarning('File {} is downloaded but the content hash does not match. '
                          'The repo may be outdated or download may be incomplete. '
                          'Otherwise you can create an issue for it.'.format(self.name))
    self.graphs, self.label = self._load_graph(mat_path)
@retry_method_with_fix(_download)
def _load(self, filename):
data = io.loadmat(self.path)
labels = data['T']
def _load_graph(self, filename):
    """Build one graph per entry of a QM7b ``.mat`` file.

    Parameters
    ----------
    filename : str
        Path of the ``.mat`` file, read with ``scipy.io.loadmat``.

    Returns
    -------
    (list, Tensor)
        The graphs in file order, and the float32 tensor built from the
        ``'T'`` entry of the file.
    """
    data = io.loadmat(filename)
    float32 = F.data_type_dict['float32']
    labels = F.tensor(data['T'], dtype=float32)
    feats = data['X']
    num_graphs = labels.shape[0]
    graphs = []
    for i in range(num_graphs):
        # Each non-zero entry of the i-th matrix in 'X' becomes one edge;
        # the entry's value is stored as that edge's 1-dimensional feature 'h'.
        edge_list = feats[i].nonzero()
        g = dgl_graph(edge_list)
        g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
                                dtype=float32)
        # Collect into the local list only: the instance attributes are
        # assigned by the caller from the returned pair.
        graphs.append(g)
    return graphs, labels
def save(self):
    """Write the graph list and its labels to ``dgl_graph.bin`` under ``self.save_path``."""
    target = str(os.path.join(self.save_path, 'dgl_graph.bin'))
    label_dict = {'labels': self.label}
    save_graphs(target, self.graphs, label_dict)
def has_cache(self):
    """Return whether ``dgl_graph.bin`` exists under ``self.save_path``."""
    return os.path.exists(os.path.join(self.save_path, 'dgl_graph.bin'))
def load(self):
    """Restore ``self.graphs`` and ``self.label`` from ``dgl_graph.bin`` under ``self.save_path``."""
    cache_file = os.path.join(self.save_path, 'dgl_graph.bin')
    cached_graphs, extras = load_graphs(cache_file)
    self.graphs = cached_graphs
    # The labels are stored under the 'labels' key of the accompanying dict.
    self.label = extras['labels']
def download(self):
    """Fetch ``self.url`` into ``<raw_dir>/<name>.mat``."""
    mat_name = self.name + '.mat'
    download(self.url, path=os.path.join(self.raw_dir, mat_name))
@property
def num_labels(self):
    """Number of labels for each graph, i.e. number of prediction tasks."""
    task_count = 14
    return task_count
def __getitem__(self, idx):
    r"""Look up one sample by position.

    Parameters
    ----------
    idx : int
        Item index

    Returns
    -------
    (dgl.DGLGraph, Tensor)
        The entry of ``self.graphs`` at ``idx`` paired with the entry of
        ``self.label`` at ``idx``.
    """
    graph = self.graphs[idx]
    target = self.label[idx]
    return graph, target
def __len__(self):
    r"""Return how many graphs the dataset holds."""
    stored = self.graphs
    return len(stored)
# Keep the name ``QM7b`` available as an alias of ``QM7bDataset``.
QM7b = QM7bDataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment