"tests/python/common/transforms/test_transform.py" did not exist on "92f87f48c765e4987b8043155e10d65d1af7dc83"
Unverified Commit 2c6d0716 authored by xnouhz's avatar xnouhz Committed by GitHub
Browse files

[Feature] QM9 Dataset Support (#2521)



* [Model] MixHop for Node Classification task

* [docs] update

* [docs] update

* [fix] remove seed option

* [fix] update readme

* [feature] support qm9 dataset

* [style] update

* [docs] fix the details

* [fix] indexing only support int

* [style] update

* [fix] multiple backends support

* [docs] add qm9

* [fix] Z type: float32 -> int32

* [fix] Z type: int32 -> long

* [docs] add ref

* [docs] fix

* [docs] update

* [docs] update

* [fix] test eval

* [docs] fix example
Co-authored-by: default avatarxnuohz@126.com <ubuntu@ip-172-31-44-184.us-east-2.compute.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 0c156573
...@@ -168,6 +168,12 @@ QM7b dataset ...@@ -168,6 +168,12 @@ QM7b dataset
.. autoclass:: QM7bDataset .. autoclass:: QM7bDataset
:members: __getitem__, __len__ :members: __getitem__, __len__
.. _qm9data:
QM9 dataset
```````````````````````````````````
.. autoclass:: QM9Dataset
:members: __getitem__, __len__
.. _minigcdataset: .. _minigcdataset:
......
""" The main file to train a MixHop model using a full graph """ """ The main file to train a MixHop model using a full graph """
import argparse import argparse
import copy
import torch import torch
import torch.optim as optim import torch.optim as optim
import torch.nn as nn import torch.nn as nn
...@@ -207,6 +208,7 @@ def main(args): ...@@ -207,6 +208,7 @@ def main(args):
batchnorm=True) batchnorm=True)
model = model.to(device) model = model.to(device)
best_model = copy.deepcopy(model)
# Step 3: Create training components ===================================================== # # Step 3: Create training components ===================================================== #
loss_fn = nn.CrossEntropyLoss() loss_fn = nn.CrossEntropyLoss()
...@@ -252,11 +254,12 @@ def main(args): ...@@ -252,11 +254,12 @@ def main(args):
else: else:
no_improvement = 0 no_improvement = 0
acc = valid_acc acc = valid_acc
best_model = copy.deepcopy(model)
scheduler.step() scheduler.step()
model.eval() best_model.eval()
logits = model(graph, feats) logits = best_model(graph, feats)
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx) test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc)) print("Test Acc {:.4f}".format(test_acc))
......
...@@ -21,6 +21,7 @@ from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset ...@@ -21,6 +21,7 @@ from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT, GDELTDataset from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset from .qm7b import QM7b, QM7bDataset
from .qm9 import QM9, QM9Dataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
......
"""QM9 dataset for graph property prediction (regression)."""
import os
import numpy as np
import scipy.sparse as sp
from .dgl_dataset import DGLDataset
from .utils import download, _get_dgl_url
from ..convert import graph as dgl_graph
from ..transform import to_bidirected
from .. import backend as F
class QM9Dataset(DGLDataset):
r"""QM9 dataset for graph property prediction (regression)
This dataset consists of 13,0831 molecules with 12 regression targets.
Node means atom and edge means bond.
Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
`"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics:
- Number of graphs: 13,0831
- Number of regression targets: 12
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Keys | Property | Description | Unit |
+========+==================================+===================================================================================+=============================================+
| mu | :math:`\mu` | Dipole moment | :math:`\textrm{D}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| alpha | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| homo | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| lumo | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| gap | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| r2 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| zpve | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
Names of the regression property, which should be a subset of the keys in the table above.
cutoff: float
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
Default: 5.0 Angstrom
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> data.num_labels
2
>>>
>>> # iterate over the dataset
>>> for g, label in data:
... R = g.ndata['R'] # get coordinates of each atom
... Z = g.ndata['Z'] # get atomic numbers of each atom
... # your code here...
>>>
"""
def __init__(self,
label_keys,
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False):
self.cutoff = cutoff
self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz')
super(QM9Dataset, self).__init__(name='qm9',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz'
data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N']
self.R = data_dict['R']
self.Z = data_dict['Z']
self.label = np.stack([data_dict[key] for key in self.label_keys], axis=1)
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
def download(self):
file_path = f'{self.raw_dir}/qm9_eV.npz'
if not os.path.exists(file_path):
download(self._url, path=file_path)
@property
def num_labels(self):
r"""
Returns
--------
int
Number of labels for each graph, i.e. number of prediction tasks.
"""
return self.label.shape[1]
def __getitem__(self, idx):
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['R']``: the coordinates of each atom
- ``ndata['Z']``: the atomic number
Tensor
Property values of molecular graphs
"""
label = F.tensor(self.label[idx], dtype=F.data_type_dict['float32'])
n_atoms = self.N[idx]
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]]
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool)
adj = adj.tocoo()
u, v = F.tensor(adj.row), F.tensor(adj.col)
g = dgl_graph((u, v))
g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
dtype=F.data_type_dict['int64'])
return g, label
def __len__(self):
r"""Number of graphs in the dataset.
Return
-------
int
"""
return self.label.shape[0]
QM9 = QM9Dataset
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment