Unverified commit 36c6c649, authored by Hengrui Zhang and committed by GitHub
Browse files

Replace QM9_v2 with built-in QM9Edge Dataset (#3026)


Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 305d5c16
import numpy as np
import os
from tqdm import tqdm
import torch as th
import dgl
from dgl.data.dgl_dataset import DGLDataset
from dgl.data.utils import download, load_graphs, _get_dgl_url, extract_archive
class QM9DatasetV2(DGLDataset):
    r"""QM9 dataset for graph property prediction (regression)

    This dataset consists of 130,831 molecules with 19 regression targets.
    Node means atom and edge means bond.

    Reference: `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_

    Atom features come from `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_

    Statistics:

    - Number of graphs: 130,831
    - Number of regression targets: 19

    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Keys   | Property                         | Description                                                                       | Unit                                        |
    +========+==================================+===================================================================================+=============================================+
    | mu     | :math:`\mu`                      | Dipole moment                                                                     | :math:`\textrm{D}`                          |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | alpha  | :math:`\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | homo   | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | lumo   | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | gap    | :math:`\Delta \epsilon`          | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | r2     | :math:`\langle R^2 \rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | zpve   | :math:`\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Cv     | :math:`c_{\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U0_atom| :math:`U_0^{\textrm{ATOM}}`      | Atomization energy at 0K                                                          | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U_atom | :math:`U^{\textrm{ATOM}}`        | Atomization energy at 298.15K                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | H_atom | :math:`H^{\textrm{ATOM}}`        | Atomization enthalpy at 298.15K                                                   | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | G_atom | :math:`G^{\textrm{ATOM}}`        | Atomization free energy at 298.15K                                                | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | A      | :math:`A`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | B      | :math:`B`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | c      | :math:`C`                        | Rotational constant                                                               | :math:`\textrm{GHz}`                        |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+

    Parameters
    ----------
    label_keys: list
        Names of the regression property, which should be a subset of the keys in the table above.
        If not provided, will load all the labels.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Attributes
    ----------
    num_labels : int
        Number of labels for each graph, i.e. number of prediction tasks

    Raises
    ------
    UserWarning
        If the raw data is changed in the remote server by the author.
    KeyError
        If ``label_keys`` names a property that is not stored in the dataset file.

    Examples
    --------
    >>> data = QM9DatasetV2(label_keys=['mu', 'alpha'])
    >>> data.num_labels
    2
    >>> # make each graph dense
    >>> data.to_dense()
    >>> # iterate over the dataset
    >>> for graph, labels in data:
    ...     print(graph) # get information of each graph
    ...     print(labels) # get labels of the corresponding graph
    ...     # your code here...
    >>>
    """

    def __init__(self,
                 label_keys=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
        # Both attributes must exist before the base constructor runs,
        # because process() and download() read them.
        self.label_keys = label_keys
        self._url = _get_dgl_url('dataset/qm9_ver2.zip')
        super(QM9DatasetV2, self).__init__(name='qm9_v2',
                                           url=self._url,
                                           raw_dir=raw_dir,
                                           force_reload=force_reload,
                                           verbose=verbose)

    def process(self):
        r"""Load the graphs and stack the requested label columns.

        Sets ``self.graphs`` (list of graphs) and ``self.labels`` (a NumPy
        array with one row per graph and one column per selected label key).
        """
        if self.verbose:
            print('begin loading dataset')
        graphs, label_dict = load_graphs(os.path.join(self.raw_dir, 'qm9_v2.bin'))
        self.graphs = graphs

        if self.label_keys is None:
            selected_keys = list(label_dict)
        else:
            selected_keys = list(self.label_keys)
            unknown = [key for key in selected_keys if key not in label_dict]
            if unknown:
                raise KeyError(
                    'Unknown label key(s) {}; available keys are {}'.format(
                        unknown, sorted(label_dict)))

        self.labels = np.stack([label_dict[key] for key in selected_keys], axis=1)

    @staticmethod
    def _densify(graph):
        r"""Return the fully connected (self-loop free) version of ``graph``.

        Existing bond features are scattered into an ``n_nodes * n_nodes``
        edge-feature table (zeros for atom pairs without a bond) and the
        pairwise Euclidean distance is appended as one extra feature column.
        """
        n_nodes = graph.num_nodes()
        # (row, col) enumerates every ordered node pair in row-major order,
        # so the pair (i, j) lives at flat position i * n_nodes + j.
        row = th.arange(n_nodes, dtype=th.long).repeat_interleave(n_nodes)
        col = th.arange(n_nodes, dtype=th.long).repeat(n_nodes)

        src, dst = graph.edges()
        flat_pos = src * n_nodes + dst

        bond_attr = graph.edata['edge_attr']
        dense_shape = list(bond_attr.size())
        dense_shape[0] = n_nodes * n_nodes
        dense_attr = bond_attr.new_zeros(dense_shape)
        dense_attr[flat_pos] = bond_attr

        pos = graph.ndata['pos']
        dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
        dense_attr = th.cat([dense_attr, dist.type_as(dense_attr)], dim=-1)

        dense_graph = dgl.graph((row, col))
        dense_graph.ndata['attr'] = graph.ndata['attr']
        dense_graph.edata['edge_attr'] = dense_attr
        return dense_graph.remove_self_loop()

    def to_dense(self):
        r"""Transform each graph to a dense graph and add an additional edge
        attribute (distance between two atoms).

        Note: the dense graphs no longer carry ``graph.ndata['pos']``. Since
        this method reads ``ndata['pos']``, it must be called at most once.
        """
        n_graph = self.labels.shape[0]
        for graph_idx in tqdm(range(n_graph), desc='processing graphs'):
            self.graphs[graph_idx] = self._densify(self.graphs[graph_idx])

    def download(self):
        r"""Fetch the dataset archive if it is missing and unpack it into ``raw_dir``."""
        file_path = os.path.join(self.raw_dir, 'qm9_v2.zip')
        if not os.path.exists(file_path):
            download(self._url, path=file_path)
        # Always unpack: extraction overwrites, so this also repairs a
        # raw_dir that holds the archive but not its extracted contents.
        extract_archive(file_path, self.raw_dir, overwrite=True)

    @property
    def num_labels(self):
        r"""
        Returns
        --------
        int
            Number of labels for each graph, i.e. number of prediction tasks.
        """
        return self.labels.shape[1]

    def __getitem__(self, idx):
        r""" Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
            The graph contains:

            - ``ndata['pos']``: the coordinates of each atom
              (absent once :meth:`to_dense` has been called)
            - ``ndata['attr']``: the atomic attributes
            - ``edata['edge_attr']``: the bond attributes
              (plus one distance column once :meth:`to_dense` has been called)
        numpy.ndarray
            Property values of the molecular graph, one entry per label key
        """
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        r"""Number of graphs in the dataset.

        Return
        -------
        int
        """
        return self.labels.shape[0]
...@@ -4,7 +4,7 @@ import torch.nn.functional as F ...@@ -4,7 +4,7 @@ import torch.nn.functional as F
import dgl import dgl
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from dgl.data.utils import Subset from dgl.data.utils import Subset
from qm9_v2 import QM9DatasetV2 from dgl.data import QM9EdgeDataset
from model import InfoGraphS from model import InfoGraphS
import argparse import argparse
...@@ -39,6 +39,69 @@ def argument(): ...@@ -39,6 +39,69 @@ def argument():
return args return args
class DenseQM9EdgeDataset(QM9EdgeDataset):
    r"""``QM9EdgeDataset`` whose items are converted to dense graphs on access.

    Each returned graph connects every ordered pair of distinct atoms; the
    original bond features are kept for bonded pairs (zeros elsewhere) and the
    inter-atomic distance is appended as an extra edge-feature column.
    """

    def __getitem__(self, idx):
        r""" Get the dense graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
            Fully connected graph without self loops. The graph contains:

            - ``ndata['attr']``: the features of each atom
            - ``edata['edge_attr']``: the features of each bond (zeros for
              atom pairs that are not bonded) followed by one column holding
              the Euclidean distance between the two atoms

            Atom coordinates are used to compute the distances but are not
            stored on the returned graph.
        Tensor
            Property values of the molecular graph
        """
        # Slice bounds of this molecule inside the flat node / edge arrays
        # held by the parent dataset (presumably cumulative counts — confirm
        # against QM9EdgeDataset).
        n_lo, n_hi = self.n_cumsum[idx], self.n_cumsum[idx + 1]
        e_lo, e_hi = self.ne_cumsum[idx], self.ne_cumsum[idx + 1]

        # Sparse molecular graph, built the same way the parent class data
        # is laid out.
        g = dgl.graph((self.src[e_lo:e_hi], self.dst[e_lo:e_hi]))
        g.ndata['pos'] = th.tensor(self.node_pos[n_lo:n_hi]).float()
        g.ndata['attr'] = th.tensor(self.node_attr[n_lo:n_hi]).float()
        g.edata['edge_attr'] = th.tensor(self.edge_attr[e_lo:e_hi]).float()
        label = th.tensor(self.targets[idx][self.label_keys]).float()

        n_nodes = g.num_nodes()
        # (row, col) enumerates every ordered node pair in row-major order,
        # so the pair (i, j) lives at flat position i * n_nodes + j.
        row = th.arange(n_nodes).repeat_interleave(n_nodes)
        col = th.arange(n_nodes).repeat(n_nodes)

        src, dst = g.edges()
        flat_pos = src * n_nodes + dst

        bond_attr = g.edata['edge_attr']
        dense_shape = list(bond_attr.size())
        dense_shape[0] = n_nodes * n_nodes
        dense_attr = bond_attr.new_zeros(dense_shape)
        dense_attr[flat_pos] = bond_attr

        pos = g.ndata['pos']
        dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
        dense_attr = th.cat([dense_attr, dist.type_as(dense_attr)], dim=-1)

        graph = dgl.graph((row, col))
        graph.ndata['attr'] = g.ndata['attr']
        graph.edata['edge_attr'] = dense_attr
        graph = graph.remove_self_loop()

        return graph, label
def collate(samples): def collate(samples):
''' collate function for building graph dataloader ''' ''' collate function for building graph dataloader '''
...@@ -76,13 +139,10 @@ if __name__ == '__main__': ...@@ -76,13 +139,10 @@ if __name__ == '__main__':
label_keys = [args.target] label_keys = [args.target]
print(args) print(args)
dataset = QM9DatasetV2(label_keys) dataset = DenseQM9EdgeDataset(label_keys = label_keys)
dataset.to_dense()
graphs = dataset.graphs
# Train/Val/Test Splitting # Train/Val/Test Splitting
N = len(graphs) N = dataset.targets.shape[0]
all_idx = np.arange(N) all_idx = np.arange(N)
np.random.shuffle(all_idx) np.random.shuffle(all_idx)
...@@ -114,7 +174,6 @@ if __name__ == '__main__': ...@@ -114,7 +174,6 @@ if __name__ == '__main__':
shuffle=True) shuffle=True)
# generate validation & testing dataloader # generate validation & testing dataloader
val_loader = GraphDataLoader(val_data, val_loader = GraphDataLoader(val_data,
batch_size=args.val_batch_size, batch_size=args.val_batch_size,
collate_fn=collate, collate_fn=collate,
...@@ -129,13 +188,6 @@ if __name__ == '__main__': ...@@ -129,13 +188,6 @@ if __name__ == '__main__':
print('======== target = {} ========'.format(args.target)) print('======== target = {} ========'.format(args.target))
mean = dataset.labels.mean().item()
std = dataset.labels.std().item()
print('mean = {:4f}'.format(mean))
print('std = {:4f}'.format(std))
in_dim = dataset[0][0].ndata['attr'].shape[1] in_dim = dataset[0][0].ndata['attr'].shape[1]
# Step 2: Create model =================================================================== # # Step 2: Create model =================================================================== #
...@@ -169,9 +221,9 @@ if __name__ == '__main__': ...@@ -169,9 +221,9 @@ if __name__ == '__main__':
sup_graph = sup_graph.to(args.device) sup_graph = sup_graph.to(args.device)
unsup_graph = unsup_graph.to(args.device) unsup_graph = unsup_graph.to(args.device)
sup_nfeat, sup_efeat = sup_graph.ndata['attr'], sup_graph.ndata['edge_attr'] sup_nfeat, sup_efeat = sup_graph.ndata['attr'], sup_graph.edata['edge_attr']
unsup_nfeat, unsup_efeat, unsup_graph_id = unsup_graph.ndata['attr'],\ unsup_nfeat, unsup_efeat, unsup_graph_id = unsup_graph.ndata['attr'],\
unsup_graph.edata['edge_attr'], unsup_graph.edata['graph_id'] unsup_graph.edata['edge_attr'], unsup_graph.ndata['graph_id']
sup_target = sup_target sup_target = sup_target
sup_target = sup_target.to(args.device) sup_target = sup_target.to(args.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment