Unverified Commit 0b47e868 authored by xnouhz's avatar xnouhz Committed by GitHub
Browse files

[Example] Add DimeNet(++) for Molecular Graph Property Prediction (#2706)



* [example] arma

* update

* update

* update

* update

* update

* [example] dimenet

* [docs] update dimenet

* [docs] update tf results

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent c88fca50
import sympy as sym
import torch
import torch.nn as nn
from modules.basis_utils import bessel_basis, real_sph_harm
from modules.envelope import Envelope
class SphericalBasisLayer(nn.Module):
def __init__(self,
num_spherical,
num_radial,
cutoff,
envelope_exponent=5):
super(SphericalBasisLayer, self).__init__()
assert num_radial <= 64
self.num_radial = num_radial
self.num_spherical = num_spherical
self.cutoff = cutoff
self.envelope = Envelope(envelope_exponent)
# retrieve formulas
self.bessel_formulas = bessel_basis(num_spherical, num_radial) # x, [num_spherical, num_radial] sympy functions
self.sph_harm_formulas = real_sph_harm(num_spherical) # theta, [num_spherical, ] sympy functions
self.sph_funcs = []
self.bessel_funcs = []
# convert to torch functions
x = sym.symbols('x')
theta = sym.symbols('theta')
modules = {'sin': torch.sin, 'cos': torch.cos}
for i in range(num_spherical):
if i == 0:
first_sph = sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)(0)
self.sph_funcs.append(lambda tensor: torch.zeros_like(tensor) + first_sph)
else:
self.sph_funcs.append(sym.lambdify([theta], self.sph_harm_formulas[i][0], modules))
for j in range(num_radial):
self.bessel_funcs.append(sym.lambdify([x], self.bessel_formulas[i][j], modules))
def get_bessel_funcs(self):
return self.bessel_funcs
def get_sph_funcs(self):
return self.sph_funcs
\ No newline at end of file
"""QM9 dataset for graph property prediction (regression)."""
import os
import numpy as np
import scipy.sparse as sp
import torch
import dgl
from tqdm import trange
from dgl.data import QM9Dataset
from dgl.data.utils import load_graphs, save_graphs
from dgl.convert import graph as dgl_graph
class QM9(QM9Dataset):
r"""QM9 dataset for graph property prediction (regression)
This dataset consists of 130,831 molecules with 12 regression targets.
Nodes correspond to atoms and edges correspond to bonds.
Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics:
- Number of graphs: 130,831
- Number of regression targets: 12
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Keys | Property | Description | Unit |
+========+==================================+===================================================================================+=============================================+
| mu | :math:`\mu` | Dipole moment | :math:`\textrm{D}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| alpha | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| homo | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| lumo | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| gap | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| r2 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| zpve | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
Names of the regression property, which should be a subset of the keys in the table above.
edge_funcs: list
A list of edge-wise user-defined functions <https://docs.dgl.ai/en/0.6.x/api/python/udf.html#edge-wise-user-defined-function> for chemical bonds. Default: None
cutoff: float
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
Default: 5.0 Angstrom
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True
Attributes
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> data.num_labels
2
>>>
>>> # iterate over the dataset
>>> for g, label in data:
... R = g.ndata['R'] # get coordinates of each atom
... Z = g.ndata['Z'] # get atomic numbers of each atom
... # your code here...
>>>
"""
def __init__(self,
label_keys,
edge_funcs=None,
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False):
self.edge_funcs = edge_funcs
self._keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
super(QM9, self).__init__(label_keys=label_keys,
cutoff=cutoff,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def has_cache(self):
""" step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
return os.path.exists(graph_path) and os.path.exists(line_graph_path)
def process(self):
""" step 3 """
npz_path = f'{self.raw_dir}/qm9_eV.npz'
data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule,
# data_dict['R'] consists of the atomic coordinates,
# data_dict['Z'] consists of the atomic numbers.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N']
self.R = data_dict['R']
self.Z = data_dict['Z']
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
# graph labels
self.label_dict = {}
for k in self._keys:
self.label_dict[k] = torch.tensor(data_dict[k], dtype=torch.float32)
self.label = torch.stack([self.label_dict[key] for key in self.label_keys], dim=1)
# graphs & features
self.graphs, self.line_graphs = self._load_graph()
def _load_graph(self):
num_graphs = self.label.shape[0]
graphs = []
line_graphs = []
for idx in trange(num_graphs):
n_atoms = self.N[idx]
# get all the atomic coordinates of the idx-th molecular graph
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]]
# calculate the distance between all atoms
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
# keep all edges that don't exceed the cutoff and delete self-loops
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool)
adj = adj.tocoo()
u, v = torch.tensor(adj.row), torch.tensor(adj.col)
g = dgl_graph((u, v))
g.ndata['R'] = torch.tensor(R, dtype=torch.float32)
g.ndata['Z'] = torch.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], dtype=torch.long)
# add user-defined features
if self.edge_funcs is not None:
for func in self.edge_funcs:
g.apply_edges(func)
graphs.append(g)
l_g = dgl.line_graph(g, backtracking=False)
line_graphs.append(l_g)
return graphs, line_graphs
def save(self):
""" step 4 """
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
save_graphs(str(graph_path), self.graphs, self.label_dict)
save_graphs(str(line_graph_path), self.line_graphs)
def load(self):
""" step 5 """
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
self.graphs, label_dict = load_graphs(graph_path)
self.line_graphs, _ = load_graphs(line_graph_path)
self.label = torch.stack([label_dict[key] for key in self.label_keys], dim=1)
def __getitem__(self, idx):
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['R']``: the coordinates of each atom
- ``ndata['Z']``: the atomic number
Tensor
Property values of molecular graphs
"""
return self.graphs[idx], self.line_graphs[idx], self.label[idx]
...@@ -20,8 +20,10 @@ class QM9Dataset(DGLDataset): ...@@ -20,8 +20,10 @@ class QM9Dataset(DGLDataset):
2. It only provides atoms' coordinates and atomic numbers as node features 2. It only provides atoms' coordinates and atomic numbers as node features
3. It only provides 12 regression targets. 3. It only provides 12 regression targets.
Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_, Reference:
`"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics: Statistics:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment