"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5d8b1987ecae5a9ee802ea2f0fdf55acd4a868af"
Unverified Commit 2c6d0716 authored by xnouhz, committed by GitHub
Browse files

[Feature] QM9 Dataset Support (#2521)



* [Model] MixHop for Node Classification task

* [docs] update

* [docs] update

* [fix] remove seed option

* [fix] update readme

* [feature] support qm9 dataset

* [style] update

* [docs] fix the details

* [fix] indexing only support int

* [style] update

* [fix] multiple backends support

* [docs] add qm9

* [fix] Z type: float32 -> int32

* [fix] Z type: int32 -> long

* [docs] add ref

* [docs] fix

* [docs] update

* [docs] update

* [fix] test eval

* [docs] fix example
Co-authored-by: xnuohz@126.com <ubuntu@ip-172-31-44-184.us-east-2.compute.internal>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 0c156573
......@@ -168,6 +168,12 @@ QM7b dataset
.. autoclass:: QM7bDataset
:members: __getitem__, __len__
.. _qm9data:
QM9 dataset
```````````````````````````````````
.. autoclass:: QM9Dataset
:members: __getitem__, __len__
.. _minigcdataset:
......
""" The main file to train a MixHop model using a full graph """
import argparse
import copy
import torch
import torch.optim as optim
import torch.nn as nn
......@@ -207,6 +208,7 @@ def main(args):
batchnorm=True)
model = model.to(device)
best_model = copy.deepcopy(model)
# Step 3: Create training components ===================================================== #
loss_fn = nn.CrossEntropyLoss()
......@@ -252,11 +254,12 @@ def main(args):
else:
no_improvement = 0
acc = valid_acc
best_model = copy.deepcopy(model)
scheduler.step()
model.eval()
logits = model(graph, feats)
best_model.eval()
logits = best_model(graph, feats)
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
......
......@@ -21,6 +21,7 @@ from .bitcoinotc import BitcoinOTC, BitcoinOTCDataset
from .gdelt import GDELT, GDELTDataset
from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .qm9 import QM9, QM9Dataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
......
"""QM9 dataset for graph property prediction (regression)."""
import os
import numpy as np
import scipy.sparse as sp
from .dgl_dataset import DGLDataset
from .utils import download, _get_dgl_url
from ..convert import graph as dgl_graph
from ..transform import to_bidirected
from .. import backend as F
class QM9Dataset(DGLDataset):
    r"""QM9 dataset for graph property prediction (regression)

    This dataset consists of 130,831 molecules with 12 regression targets.
    Node means atom and edge means bond.

    Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
    `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_

    Statistics:

    - Number of graphs: 130,831
    - Number of regression targets: 12

    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Keys   | Property                         | Description                                                                       | Unit                                        |
    +========+==================================+===================================================================================+=============================================+
    | mu     | :math:`\mu`                      | Dipole moment                                                                     | :math:`\textrm{D}`                          |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | alpha  | :math:`\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | homo   | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | lumo   | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | gap    | :math:`\Delta \epsilon`          | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | r2     | :math:`\langle R^2 \rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | zpve   | :math:`\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Cv     | :math:`c_{\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+

    Parameters
    ----------
    label_keys: list
        Names of the regression property, which should be a subset of the keys in the table above.
    cutoff: float
        Cutoff distance for interatomic interactions, i.e. two atoms are connected in the
        corresponding graph if the distance between them is no larger than this.
        Default: 5.0 Angstrom
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: False.

    Attributes
    ----------
    num_labels : int
        Number of labels for each graph, i.e. number of prediction tasks

    Raises
    ------
    UserWarning
        If the raw data is changed in the remote server by the author.

    Examples
    --------
    >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
    >>> data.num_labels
    2
    >>>
    >>> # iterate over the dataset
    >>> for g, label in data:
    ...     R = g.ndata['R']  # get coordinates of each atom
    ...     Z = g.ndata['Z']  # get atomic numbers of each atom
    ...     # your code here...
    >>>
    """

    def __init__(self,
                 label_keys,
                 cutoff=5.0,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False):
        # These must be set before the base-class constructor runs: they are
        # read by ``download`` / ``process`` / ``__getitem__`` below.
        self.cutoff = cutoff
        self.label_keys = label_keys
        self._url = _get_dgl_url('dataset/qm9_eV.npz')

        super(QM9Dataset, self).__init__(name='qm9',
                                         url=self._url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose)

    def process(self):
        """Load the raw ``qm9_eV.npz`` archive into in-memory NumPy arrays."""
        npz_path = os.path.join(self.raw_dir, 'qm9_eV.npz')
        # NOTE(review): allow_pickle=True lets the archive execute arbitrary
        # code while loading; it is kept for compatibility with the published
        # file -- confirm whether the archive really needs it.
        # The context manager closes the underlying zip file handle once the
        # arrays have been materialised.
        with np.load(npz_path, allow_pickle=True) as data_dict:
            # data_dict['N'] contains the number of atoms in each molecule.
            # Atomic properties (Z and R) of all molecules are concatenated as
            # single tensors, so you need this value to select the correct
            # atoms for each molecule.
            self.N = data_dict['N']
            self.R = data_dict['R']
            self.Z = data_dict['Z']
            self.label = np.stack([data_dict[key] for key in self.label_keys], axis=1)
        # N_cumsum[i]:N_cumsum[i + 1] is the slice of R / Z owned by molecule i.
        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])

    def download(self):
        """Download ``qm9_eV.npz`` into ``raw_dir`` unless it already exists."""
        file_path = os.path.join(self.raw_dir, 'qm9_eV.npz')
        if not os.path.exists(file_path):
            download(self._url, path=file_path)

    @property
    def num_labels(self):
        r"""
        Returns
        --------
        int
            Number of labels for each graph, i.e. number of prediction tasks.
        """
        return self.label.shape[1]

    def __getitem__(self, idx):
        r""" Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
            The graph contains:

            - ``ndata['R']``: the coordinates of each atom
            - ``ndata['Z']``: the atomic number

        Tensor
            Property values of molecular graphs
        """
        label = F.tensor(self.label[idx], dtype=F.data_type_dict['float32'])
        n_atoms = self.N[idx]
        start, end = self.N_cumsum[idx], self.N_cumsum[idx + 1]
        R = self.R[start:end]
        # Pairwise distances between all atoms of this molecule.
        dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
        # Connect atoms within the cutoff; subtracting the identity drops the
        # self-loops that the zero self-distance would otherwise create.
        # ``np.bool_`` is used because the ``np.bool`` alias was removed from NumPy.
        adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool_)
        adj = adj.tocoo()
        u, v = F.tensor(adj.row), F.tensor(adj.col)
        g = dgl_graph((u, v))
        g = to_bidirected(g)
        g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
        g.ndata['Z'] = F.tensor(self.Z[start:end],
                                dtype=F.data_type_dict['int64'])
        return g, label

    def __len__(self):
        r"""Number of graphs in the dataset.

        Return
        -------
        int
        """
        return self.label.shape[0]


# Short alias for the dataset class.
QM9 = QM9Dataset
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment