qm9.py 11.2 KB
Newer Older
1
2
"""QM9 dataset for graph property prediction (regression)."""
import os
3

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
4
5
import dgl

6
7
8
import numpy as np
import scipy.sparse as sp
import torch
9
from dgl.convert import graph as dgl_graph
10
11
from dgl.data import QM9Dataset
from dgl.data.utils import load_graphs, save_graphs
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
12
from tqdm import trange
13

14
15
16
17
18
19
20

class QM9(QM9Dataset):
    r"""QM9 dataset for graph property prediction (regression)

    This dataset consists of 130,831 molecules with 12 regression targets.
    Nodes correspond to atoms and edges correspond to bonds.

    Reference:

    - `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_
    - `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_

    Statistics:

    - Number of graphs: 130,831
    - Number of regression targets: 12

    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Keys   | Property                         | Description                                                                       | Unit                                        |
    +========+==================================+===================================================================================+=============================================+
    | mu     | :math:`\mu`                      | Dipole moment                                                                     | :math:`\textrm{D}`                          |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | alpha  | :math:`\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | homo   | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | lumo   | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | gap    | :math:`\Delta \epsilon`          | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | r2     | :math:`\langle R^2 \rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | zpve   | :math:`\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\textrm{eV}`                         |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
    | Cv     | :math:`c_{\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+

    Parameters
    ----------
    label_keys: list
        Names of the regression property, which should be a subset of the keys in the table above.
    edge_funcs: list
        A list of edge-wise user-defined functions <https://docs.dgl.ai/en/0.6.x/api/python/udf.html#edge-wise-user-defined-function> for chemical bonds. Default: None
    cutoff: float
        Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
        Default: 5.0 Angstrom
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: False

    Attributes
    ----------
    num_labels : int
        Number of labels for each graph, i.e. number of prediction tasks

    Raises
    ------
    UserWarning
        If the raw data is changed in the remote server by the author.

    Examples
    --------
    >>> data = QM9(label_keys=['mu', 'gap'], cutoff=5.0)
    >>> data.num_labels
    2
    >>>
    >>> # iterate over the dataset; each item is (graph, line graph, label)
    >>> for g, l_g, label in data:
    ...     R = g.ndata['R'] # get coordinates of each atom
    ...     Z = g.ndata['Z'] # get atomic numbers of each atom
    ...     # your code here...
    >>>
    """

100
101
102
103
104
105
106
107
108
    def __init__(
        self,
        label_keys,
        edge_funcs=None,
        cutoff=5.0,
        raw_dir=None,
        force_reload=False,
        verbose=False,
    ):
        """Record the subclass-specific settings, then run the parent constructor.

        Parameters
        ----------
        label_keys : list
            Names of the regression targets to expose. ``process`` and
            ``load`` look each name up among the keys listed in
            ``self._keys``.
        edge_funcs : list, optional
            Edge-wise user-defined functions; ``_load_graph`` passes each of
            them to ``apply_edges`` on every graph. Default: None
        cutoff : float
            Forwarded to the parent constructor. ``_load_graph`` connects two
            atoms when their distance is ``<= self.cutoff``. Default: 5.0
        raw_dir : str, optional
            Forwarded to the parent constructor. Default: None
        force_reload : bool
            Forwarded to the parent constructor. Default: False
        verbose : bool
            Forwarded to the parent constructor. Default: False
        """
        # Both attributes are assigned before the parent constructor runs.
        # NOTE(review): presumably the parent __init__ triggers
        # process()/load(), which read them -- confirm against the parent.
        self.edge_funcs = edge_funcs
        self._keys = [
            "mu",
            "alpha",
            "homo",
            "lumo",
            "gap",
            "r2",
            "zpve",
            "U0",
            "U",
            "H",
            "G",
            "Cv",
        ]

        super().__init__(
            label_keys=label_keys,
            cutoff=cutoff,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
        )
132

133
134
135
136
137
138
139
140
    @property
    def graph_path(self):
        """Path of the cached molecular-graph file under ``self.save_path``."""
        return "{}/dgl_graph.bin".format(self.save_path)

    @property
    def line_graph_path(self):
        """Path of the cached line-graph file under ``self.save_path``."""
        return "{}/dgl_line_graph.bin".format(self.save_path)

141
    def has_cache(self):
        """step 1, if True, goto step 5; else goto download(step 2), then step 3

        Returns
        -------
        bool
            True only when both the graph cache file and the line-graph
            cache file exist on disk.
        """
        # The graph file is checked first; the line-graph file is only
        # looked at when the graph file is present.
        if not os.path.exists(self.graph_path):
            return False
        return os.path.exists(self.line_graph_path)
146
147

    def process(self):
        """step 3: read the raw ``qm9_eV.npz`` archive and build graphs and labels.

        Sets ``self.N``, ``self.R``, ``self.Z``, ``self.N_cumsum``,
        ``self.label_dict``, ``self.label``, ``self.graphs`` and
        ``self.line_graphs``.

        Raises
        ------
        KeyError
            If an entry of ``self.label_keys`` is not one of ``self._keys``.
        """
        npz_path = f"{self.raw_dir}/qm9_eV.npz"
        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays
        # from the archive; only use this with a trusted raw file.
        # The archive is opened in a ``with`` block so its file handle is
        # released once the arrays have been read.
        with np.load(npz_path, allow_pickle=True) as data_dict:
            # data_dict['N'] contains the number of atoms in each molecule,
            # data_dict['R'] consists of the atomic coordinates,
            # data_dict['Z'] consists of the atomic numbers.
            # Atomic properties (Z and R) of all molecules are concatenated
            # as single tensors, so N is needed to select the correct atoms
            # for each molecule.
            self.N = data_dict["N"]
            self.R = data_dict["R"]
            self.Z = data_dict["Z"]
            # graph labels: one float32 tensor per property key
            self.label_dict = {
                k: torch.tensor(data_dict[k], dtype=torch.float32)
                for k in self._keys
            }

        # N_cumsum[i] : N_cumsum[i + 1] is the atom range of molecule i.
        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])

        # One column per requested label key, in the order of label_keys.
        self.label = torch.stack(
            [self.label_dict[key] for key in self.label_keys], dim=1
        )
        # graphs & features
        self.graphs, self.line_graphs = self._load_graph()
170

171
172
173
174
    def _load_graph(self):
        """Build a graph and its line graph for every molecule.

        For each molecule, atoms become nodes and a directed edge is added
        between every ordered pair of distinct atoms whose distance is no
        larger than ``self.cutoff``. Node features ``R`` (coordinates,
        float32) and ``Z`` (atomic numbers, int64) are attached, then every
        function in ``self.edge_funcs`` (if any) is applied edge-wise.

        Returns
        -------
        tuple of (list, list)
            The molecular graphs and the corresponding line graphs, in
            molecule order.
        """
        num_graphs = self.label.shape[0]
        graphs = []
        line_graphs = []

        for idx in trange(num_graphs):
            n_atoms = int(self.N[idx])
            # Atoms of all molecules are stored back to back; this slice
            # selects the atoms of the idx-th molecule.
            atoms = slice(self.N_cumsum[idx], self.N_cumsum[idx + 1])
            # get all the atomic coordinates of the idx-th molecular graph
            R = self.R[atoms]
            # calculate the distance between all atoms
            dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
            # keep all edges that don't exceed the cutoff and delete self-loops
            adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(
                n_atoms, dtype=np.bool_
            )
            adj = adj.tocoo()
            u, v = torch.tensor(adj.row), torch.tensor(adj.col)
            # Pass the node count explicitly: without it the graph size is
            # inferred from the largest edge endpoint, so an atom with no
            # neighbour inside the cutoff at the end of the molecule would
            # be dropped and the feature assignment below would not match.
            g = dgl_graph((u, v), num_nodes=n_atoms)
            g.ndata["R"] = torch.tensor(R, dtype=torch.float32)
            g.ndata["Z"] = torch.tensor(self.Z[atoms], dtype=torch.long)

            # add user-defined features
            if self.edge_funcs is not None:
                for func in self.edge_funcs:
                    g.apply_edges(func)

            graphs.append(g)
            line_graphs.append(dgl.line_graph(g, backtracking=False))

        return graphs, line_graphs

    def save(self):
        """step 4: write the graphs and line graphs to their cache files.

        The molecular graphs are stored together with the full
        ``self.label_dict``; the line graphs are stored without labels.
        """
        save_graphs(str(self.graph_path), self.graphs, self.label_dict)
        save_graphs(str(self.line_graph_path), self.line_graphs)
210
211

    def load(self):
        """step 5: restore graphs, line graphs and labels from the cache files.

        Sets ``self.graphs``, ``self.line_graphs`` and ``self.label``; the
        label tensor has one column per entry of ``self.label_keys``, taken
        from the label dictionary stored next to the graphs.
        """
        self.graphs, label_dict = load_graphs(self.graph_path)
        # The line-graph file carries no labels, so its second return value
        # is ignored.
        self.line_graphs, _ = load_graphs(self.line_graph_path)
        self.label = torch.stack(
            [label_dict[key] for key in self.label_keys], dim=1
        )
218
219

    def __getitem__(self, idx):
220
        r"""Get graph and label by index
221
222
223
224
225

        Parameters
        ----------
        idx : int
            Item index
226

227
228
229
230
231
232
233
234
235
236
        Returns
        -------
        dgl.DGLGraph
            The graph contains:
            - ``ndata['R']``: the coordinates of each atom
            - ``ndata['Z']``: the atomic number
        Tensor
            Property values of molecular graphs
        """
        return self.graphs[idx], self.line_graphs[idx], self.label[idx]