Unverified Commit a1f59c3b authored by Hengrui Zhang's avatar Hengrui Zhang Committed by GitHub
Browse files

[Feature] QM9Edge Dataset Support (#2704)



* [Feature] Support QM9Edge Datset

* Update qm9_edge.py

* disable tqdm

* Update qm9_edge.py

* Update qm9_edge.py

* Update qm9_edge.py

* Update qm9_edge.py

* Update qm9_edge.py

* Update qm9_edge.py

* Update qm9_edge.py

* remove preprocessing part

* add comparisons in qm9.py

* [docs] add qm9edge dataset
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 62dd1c86
...@@ -175,6 +175,13 @@ QM9 dataset ...@@ -175,6 +175,13 @@ QM9 dataset
.. autoclass:: QM9Dataset .. autoclass:: QM9Dataset
:members: __getitem__, __len__ :members: __getitem__, __len__
.. _qm9edgedata:
QM9Edge dataset
```````````````````````````````````
.. autoclass:: QM9EdgeDataset
:members: __getitem__, __len__
.. _minigcdataset: .. _minigcdataset:
Mini graph classification dataset Mini graph classification dataset
......
...@@ -22,6 +22,7 @@ from .gdelt import GDELT, GDELTDataset ...@@ -22,6 +22,7 @@ from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset from .qm7b import QM7b, QM7bDataset
from .qm9 import QM9, QM9Dataset from .qm9 import QM9, QM9Dataset
from .qm9_edge import QM9Edge, QM9EdgeDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
......
...@@ -12,15 +12,20 @@ from .. import backend as F ...@@ -12,15 +12,20 @@ from .. import backend as F
class QM9Dataset(DGLDataset): class QM9Dataset(DGLDataset):
r"""QM9 dataset for graph property prediction (regression) r"""QM9 dataset for graph property prediction (regression)
This dataset consists of 13,0831 molecules with 12 regression targets. This dataset consists of 130,831 molecules with 12 regression targets.
Node means atom and edge means bond. Nodes correspond to atoms and edges correspond to close atom pairs.
This dataset differs from :class:`~dgl.data.QM9EdgeDataset` in the following aspects:
1. Edges in this dataset are purely distance-based.
2. It only provides atoms' coordinates and atomic numbers as node features
3. It only provides 12 regression targets.
Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_, Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
`"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_ `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics: Statistics:
- Number of graphs: 13,0831 - Number of graphs: 130,831
- Number of regression targets: 12 - Number of regression targets: 12
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
......
""" QM9 dataset for graph property prediction (regression) """
import os
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F
class QM9EdgeDataset(DGLDataset):
r"""QM9Edge dataset for graph property prediction (regression)
This dataset consists of 130,831 molecules with 19 regression targets.
Nodes correspond to atoms and edges correspond to bonds.
This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:
1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.
3. It provides another 7 regression tasks(from 12 to 19).
This class is built based on a preprocessed dataset version, and we provide the preprocessing datails `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_
Reference:
- `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
- `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
For
Statistics:
- Number of graphs: 130,831.
- Number of regression targets: 19.
Node attributes:
- pos: the 3D coordinates of each atom.
- attr: the 11D atom features.
Edge attributes:
- edge_attr: the 4D bond features.
Regression targets:
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Keys | Property | Description | Unit |
+========+==================================+===================================================================================+=============================================+
| mu | :math:`\mu` | Dipole moment | :math:`\textrm{D}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| alpha | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| homo | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| lumo | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| gap | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| r2 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| zpve | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0_atom| :math:`U_0^{\textrm{ATOM}}` | Atomization energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U_atom | :math:`U^{\textrm{ATOM}}` | Atomization energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H_atom | :math:`H^{\textrm{ATOM}}` | Atomization enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G_atom | :math:`G^{\textrm{ATOM}}` | Atomization free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| A | :math:`A` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| B | :math:`B` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+---------------------------------------------------------------------------------------------------------------------------------+
Parameters
----------
label_keys: list
Names of the regression property, which should be a subset of the keys in the table above.
If not provided, it will load all the labels.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False.
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
>>> data.num_labels
2
>>> # iterate over the dataset
>>> for graph, labels in data:
... print(graph) # get information of each graph
... print(labels) # get labels of the corresponding graph
... # your code here...
>>>
"""
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom',
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
map_dict = {}
for i, key in enumerate(keys):
map_dict[key] = i
def __init__(self,
label_keys=None,
raw_dir=None,
force_reload=False,
verbose=True):
if label_keys == None:
self.label_keys = None
self.num_labels = 19
else:
self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys)
self._url = _get_dgl_url('dataset/qm9_edge.npz')
super(QM9EdgeDataset, self).__init__(name='qm9Edge',
raw_dir=raw_dir,
url=self._url,
force_reload=force_reload,
verbose=verbose)
def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz'
if not os.path.exists(file_path):
download(self._url, path=file_path)
def process(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz'
data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict['n_node']
self.n_edge = data_dict['n_edge']
self.node_attr = data_dict['node_attr']
self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr']
self.target = data_dict['target']
self.src = data_dict['src']
self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz'
return os.path.exists(npz_path)
def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz',
n_node=self.n_node,
n_edge=self.n_edge,
node_attr=self.node_attr,
node_pos=self.node_pos,
edge_attr=self.edge_attr,
src=self.src,
dst=self.dst,
targets=self.targets)
def load(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz'
data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict['n_node']
self.n_edge = data_dict['n_edge']
self.node_attr = data_dict['node_attr']
self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr']
self.targets = data_dict['targets']
self.src = data_dict['src']
self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx):
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['pos']``: the coordinates of each atom
- ``ndata['attr']``: the features of each atom
- ``edata['edge_attr']``: the features of each bond
Tensor
Property values of molecular graphs
"""
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32'])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32'])
return g, label
def __len__(self):
r""" Number of graphs in the dataset.
Returns
-------
int
"""
return self.n_node.shape[0]
QM9Edge = QM9EdgeDataset
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment