run.py 9.6 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import os
import sys
import time
import numpy as np
import torch
from pyscf import gto, lib
try:
    import deepks
except ImportError as e:
    sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../")
from deepks.scf.scf import DSCF, UDSCF
from deepks.scf.fields import select_fields
from deepks.scf.penalty import select_penalty
from deepks.model.model import CorrNet
from deepks.utils import check_list, flat_file_list
from deepks.utils import is_xyz, load_sys_paths
from deepks.utils import load_yaml, load_array
from deepks.utils import get_sys_name, get_with_prefix

# Coordinate unit used by build_mol when no "unit" attribute is supplied.
DEFAULT_UNIT = "Bohr"

# Field names computed and dumped by main() when none are requested.
DEFAULT_FNAMES = set(("e_tot", "e_base", "dm_eig", "conv"))

# Default SCF options when running without a model file.
DEFAULT_HF_ARGS = dict(conv_tol=1e-9)

# Default SCF options when a model file is loaded.
# Other knobs that can be added here: level_shift (e.g. 0.1), diis_space (e.g. 20).
DEFAULT_SCF_ARGS = dict(conv_tol=1e-7)

# Optional per-system molecule attributes that system_iter loads from files
# and passes on to build_mol.
MOL_ATTRIBUTE = set(("charge", "basis", "unit"))

def solve_mol(mol, model, fields, labels=None,
              proj_basis=None, penalties=None, device=None,
              chkfile=None, verbose=0,
              **scf_args):
    """Run a (D)SCF calculation on one molecule and evaluate output fields.

    mol: built pyscf molecule; DSCF is used when ``mol.spin == 0``,
        UDSCF otherwise.
    model: model handed to the SCF class (main passes None when no model
        file is used).
    fields: dict with "scf" and "grad" lists of field objects, each having
        ``name``, ``required_labels`` and ``calc``. "scf" fields are evaluated
        on the SCF object after ``kernel()``; "grad" fields on its nuclear
        gradient object, which is only computed if "grad" is non-empty.
    labels: mapping label name -> value used to feed ``required_labels``.
    proj_basis, penalties, device: forwarded to the SCF class constructor.
    chkfile, verbose: set on the SCF object.
    **scf_args: further SCF attributes; an optional "grids" entry is removed
        and applied to the SCF object's ``grids`` instead.

    Returns ``(meta, res)`` where ``meta`` is
    ``np.array([natom, natom_raw, nao, nproj])`` (natom from the SCF
    object's ``_pmol``, natom_raw and nao from ``mol``, nproj from the SCF
    object) and ``res`` maps field name -> value returned by its ``calc``.
    """
    start = time.time()
    label_values = {} if labels is None else labels

    scf_cls = UDSCF if mol.spin != 0 else DSCF
    solver = scf_cls(mol, model,
                     proj_basis=proj_basis,
                     penalties=penalties,
                     device=device)
    solver.set(chkfile=chkfile, verbose=verbose)
    # grid options belong to a sub-object, so split them off before set()
    grid_opts = scf_args.pop("grids", {})
    solver.set(**scf_args)
    solver.grids.set(**grid_opts)
    solver.kernel()

    meta = np.array([solver._pmol.natm,   # natom
                     mol.natm,            # natom_raw
                     mol.nao,             # nao
                     solver.nproj])       # nproj

    results = {}
    for field in fields["scf"]:
        needed = {key: label_values[key] for key in field.required_labels}
        results[field.name] = field.calc(solver, **needed)
    grad_fields = fields["grad"]
    if grad_fields:
        grad_solver = solver.nuc_grad_method().run()
        for field in grad_fields:
            needed = {key: label_values[key] for key in field.required_labels}
            results[field.name] = field.calc(grad_solver, **needed)

    elapsed = time.time() - start
    if verbose:
        print(f"time of scf: {elapsed:6.2f}s, converged:   {solver.converged}")
    return meta, results


def get_required_labels(fields=None, penalty_dicts=None):
    """Return the set of label names needed by the given fields and penalties.

    fields: field object(s) exposing ``required_labels`` (None -> none).
    penalty_dicts: penalty config dict(s). A dict's own "required_labels"
        entry wins; only when it is absent is the penalty class looked up via
        ``select_penalty(d["type"])`` to get its default ``required_labels``.
        Looking the class up lazily avoids resolving it needlessly.

    Raises KeyError if a penalty dict has neither "required_labels" nor "type".
    """
    labels = set()
    for fd in check_list(fields):
        labels.update(check_list(fd.required_labels))
    for pd in check_list(penalty_dicts):
        if "required_labels" in pd:
            names = pd["required_labels"]
        else:
            names = select_penalty(pd["type"]).required_labels
        # set.update avoids the quadratic list concatenation of sum(lists, [])
        labels.update(check_list(names))
    return labels


def system_iter(path, labels=None):
    """
    Yield ``(atom, attrs, labels)`` for every frame of one system.

    path: either an .xyz file (a single frame; ``atom`` is the path itself)
        or a folder containing ``atom.npy`` or ``coord.npy`` + ``type.raw``
        (one or more frames; ``atom`` is a list of ``[element, coord]``).
    labels: set of required label names; each is loaded from
        ``$base[.|/]$label`` (``.npy`` preferred), where ``$base`` is the
        basename of the xyz file (followed by ``.``) or the folder
        (followed by ``/``).

    ``attrs`` contains whichever MOL_ATTRIBUTE entries have a matching file.
    For an xyz system ``unit`` defaults to "Angstrom". In a folder, an
    attribute whose leading dimension equals the frame count is sliced per
    frame; otherwise it is shared by all frames. Folder labels are always
    sliced per frame. xyz labels are assumed to be single-frame.
    """
    label_names = set() if labels is None else labels
    base = get_sys_name(path)

    # Optional molecule attributes: look each one up (nullable lookup) and
    # keep only those whose file actually exists.
    attr_paths = {}
    for attr_name in MOL_ATTRIBUTE:
        attr_file = get_with_prefix(attr_name, base, ".npy", True)
        if attr_file is not None:
            attr_paths[attr_name] = attr_file
    label_paths = {name: get_with_prefix(name, base, prefer=".npy")
                   for name in label_names}

    if is_xyz(path):
        # Single frame: every attribute and label applies to it unchanged.
        attr_dict = {name: load_array(fpath) for name, fpath in attr_paths.items()}
        attr_dict.setdefault("unit", "Angstrom")
        label_dict = {name: load_array(label_paths[name]) for name in label_names}
        yield path, attr_dict, label_dict
        return

    # Folder with possibly many frames: load everything, then slice per frame.
    assert os.path.isdir(path), f"system {path} is neither .xyz or dir"
    shared_attrs = {name: load_array(fpath) for name, fpath in attr_paths.items()}
    shared_labels = {name: load_array(label_paths[name]) for name in label_names}
    try:
        atom_array = load_array(get_with_prefix("atom", path, prefer=".npy"))
    except FileNotFoundError:
        # Fall back to coord.npy plus one type.raw row shared by all frames.
        coords = load_array(get_with_prefix("coord", path, prefer=".npy"))
        assert len(coords.shape) == 3 and coords.shape[2] == 3, coords.shape
        nframes = coords.shape[0]
        type_row = np.loadtxt(os.path.join(path, "type.raw"), dtype=str).reshape(1, -1)
        elements = type_row.repeat(nframes, axis=0)
    else:
        # Column 0 holds the element (rounded to int, presumably the atomic
        # number); columns 1: hold the coordinates.
        assert len(atom_array.shape) == 3 and atom_array.shape[2] == 4, atom_array.shape
        nframes = atom_array.shape[0]
        elements = np.rint(atom_array[:, :, 0]).astype(int)
        coords = atom_array[:, :, 1:]

    for frame in range(nframes):
        atom = [[elem, xyz] for elem, xyz in zip(elements[frame], coords[frame])]
        attr_dict = {}
        for name, value in shared_attrs.items():
            per_frame = value.ndim > 0 and value.shape[0] == nframes
            attr_dict[name] = value[frame] if per_frame else value
        label_dict = {name: shared_labels[name][frame] for name in label_names}
        yield atom, attr_dict, label_dict


def _to_python_scalar(value):
    """Return ``value.item()`` for numpy scalars and 0-d arrays, else ``value``."""
    if isinstance(value, np.generic) or (isinstance(value, np.ndarray) and value.ndim == 0):
        return value.item()
    return value


def build_mol(atom, basis='ccpvdz', unit=DEFAULT_UNIT, verbose=0, **kwargs):
    """Build a pyscf ``gto.Mole`` from an atom specification.

    atom: atom input as yielded by system_iter (an .xyz path or a list of
        ``[element, coord]`` pairs), assigned to ``mol.atom``.
    basis: basis set assigned to ``mol.basis`` (default cc-pVDZ).
    unit: coordinate unit, default DEFAULT_UNIT; a numpy array is converted
        with ``tolist()``.
    verbose: pyscf verbosity level.
    **kwargs: extra Mole attributes applied with ``mol.set`` (e.g. charge).

    Numpy scalars and 0-d arrays passed as ``basis`` or in ``kwargs`` (such
    as values system_iter loads from ``.npy`` files) are converted to plain
    Python objects first, since e.g. a 0-d string array is not a valid basis.
    The spin is always reset to ``nelectron % 2``, so any ``spin`` given in
    kwargs is overridden.
    """
    mol = gto.Mole()
    # change minimum max memory to 16G
    # mol.max_memory = max(16000, mol.max_memory)
    if isinstance(unit, np.ndarray):
        unit = unit.tolist()
    mol.unit = unit
    mol.atom = atom
    mol.basis = _to_python_scalar(basis)
    mol.verbose = verbose
    mol.set(**{key: _to_python_scalar(val) for key, val in kwargs.items()})
    # lowest spin compatible with the electron count (0 for even, 1 for odd)
    mol.spin = mol.nelectron % 2
    mol.build(0, 0)
    return mol


def build_penalty(pnt_dict, label_dict=None):
    """Instantiate a penalty term from its config dict.

    pnt_dict: config with a "type" key selecting the class via
        ``select_penalty``, an optional "required_labels" overriding the
        class default, and any remaining keys passed as keyword arguments to
        the class constructor. The caller's dict is not modified.
    label_dict: mapping label name -> array (None means no labels). The
        required labels are passed positionally, in the order listed.

    Raises KeyError if a required label is missing from ``label_dict``.
    """
    if label_dict is None:
        label_dict = {}
    pnt_dict = pnt_dict.copy()
    pnt_type = pnt_dict.pop("type")
    PenaltyClass = select_penalty(pnt_type)
    label_names = check_list(pnt_dict.pop("required_labels", PenaltyClass.required_labels))
    missing = [lb for lb in label_names if lb not in label_dict]
    if missing:
        raise KeyError(f"penalty {pnt_type!r} requires missing labels: {missing}")
    label_arrays = [label_dict[lb] for lb in label_names]
    return PenaltyClass(*label_arrays, **pnt_dict)


def collect_fields(fields, meta, res_list):
    """Stack per-frame field results into one array per field.

    fields: list of field objects (each with ``name`` and ``shape``), or a
        dict whose values are such lists (they are concatenated).
    meta: sequence unpacked as ``natom, natom_raw, nao, nproj``.
    res_list: list of per-frame result dicts (field name -> value), or a
        single such dict treated as one frame.

    Returns a dict mapping field name -> ``np.array`` of that field's values
    over all frames; when ``fd.shape`` is non-empty the array is reshaped to
    the shape that expression evaluates to.
    """
    if isinstance(fields, dict):
        fields = sum(fields.values(), [])
    if isinstance(res_list, dict):
        res_list = [res_list]
    # nframe and the unpacked meta values are not used directly: they are
    # visible to the eval() below (presumably via fd.shape expressions).
    nframe = len(res_list)
    natom, natom_raw, nao, nproj = meta
    res_dict = {}
    for fd in fields:
        fd_res = np.array([res[fd.name] for res in res_list])
        if fd.shape:
            # fd.shape is a Python expression string evaluated against this
            # function's locals, so renaming locals here changes what it sees.
            # NOTE(review): eval is only acceptable because shape strings come
            # from field definitions, not user input - confirm.
            fd_shape = eval(fd.shape, {}, locals())
            fd_res = fd_res.reshape(fd_shape)
        res_dict[fd.name] = fd_res
    return res_dict


def dump_meta(dir_name, meta):
    """Write ``meta`` as one integer row to ``<dir_name>/system.raw``.

    The directory is created if needed. The file gets a commented header
    naming the columns: natom natom_raw nao nproj.
    """
    target = os.path.join(dir_name, 'system.raw')
    os.makedirs(dir_name, exist_ok=True)
    row = np.asarray(meta).reshape(1, -1)
    np.savetxt(target, row, fmt='%d', header='natom natom_raw nao nproj')


def dump_data(dir_name, **data_dict):
    """Save every keyword argument as ``<dir_name>/<name>.npy`` via np.save.

    The directory is created if needed; existing files are overwritten.
    """
    os.makedirs(dir_name, exist_ok=True)
    targets = {os.path.join(dir_name, f'{key}.npy'): val
               for key, val in data_dict.items()}
    for file_path, array in targets.items():
        np.save(file_path, array)


def main(systems, model_file="model.pth", basis='ccpvdz', 
         proj_basis=None, penalty_terms=None, device=None,
         dump_dir=".", dump_fields=DEFAULT_FNAMES, group=False, 
         mol_args=None, scf_args=None, verbose=0):
    """Run SCF on every frame of every system and dump the requested fields.

    systems: system path(s), resolved with load_sys_paths. Each is an .xyz
        file or a folder understood by system_iter.
    model_file: CorrNet model path. None or "NONE" (any case) runs without a
        model and uses DEFAULT_HF_ARGS as default SCF options. Otherwise
        DEFAULT_SCF_ARGS is used.
    basis: basis set passed to build_mol.
    proj_basis, device: forwarded to solve_mol.
    penalty_terms: penalty config dict(s), built per frame by build_penalty.
    dump_dir: output root directory.
    dump_fields: names of fields to compute and save (see select_fields).
    group: if False, each system's results go to ``dump_dir/<system name>``.
        If True, all systems are concatenated into ``dump_dir``; processing
        stops at the first system whose meta differs from the previous one,
        and only the results collected before it are saved.
    mol_args: extra build_mol keyword arguments. verbose/atom/basis and
        per-system attributes take precedence over them.
    scf_args: SCF options overriding the defaults above.
    verbose: verbosity; values > 1 also print basis and SCF options.

    Any exception while solving a frame is reported on stderr and re-raised.
    """
    if model_file is None or model_file.upper() == "NONE":
        model = None
        default_scf_args = DEFAULT_HF_ARGS
    else:
        model = CorrNet.load(model_file).double()
        default_scf_args = DEFAULT_SCF_ARGS

    # check arguments
    penalty_terms = check_list(penalty_terms)
    if mol_args is None:
        mol_args = {}
    if scf_args is None:
        scf_args = {}
    scf_args = {**default_scf_args, **scf_args}
    fields = select_fields(dump_fields)
    # check label names from label fields and penalties
    label_names = get_required_labels(fields["scf"]+fields["grad"], penalty_terms)

    if verbose:
        print(f"starting calculation with OMP threads: {lib.num_threads()}",
              f"and max memory: {lib.param.MAX_MEMORY}")
        if verbose > 1:
            print(f"basis: {basis}")
            print(f"specified scf args:\n  {scf_args}")

    meta = old_meta = None
    res_list = []
    systems = load_sys_paths(systems)

    for fl in systems:
        fl = fl.rstrip(os.path.sep)
        for atom, attrs, labels in system_iter(fl, label_names):
            mol_input = {**mol_args, "verbose":verbose, 
                        "atom": atom, "basis": basis,  **attrs}
            mol = build_mol(**mol_input)
            penalties = [build_penalty(pd, labels) for pd in penalty_terms]
            try:
                meta, result = solve_mol(mol, model, fields, labels,
                                         proj_basis=proj_basis, penalties=penalties,
                                         device=device, verbose=verbose, **scf_args)
            except Exception as e:
                print(fl, 'failed! error:', e, file=sys.stderr)
                raise
            if group and old_meta is not None and np.any(meta != old_meta):
                break
            res_list.append(result)

        if not group:
            sub_dir = os.path.join(dump_dir, get_sys_name(os.path.basename(fl)))
            dump_meta(sub_dir, meta)
            dump_data(sub_dir, **collect_fields(fields, meta, res_list))
            res_list = []
        elif old_meta is not None and np.any(meta != old_meta):
            print(fl, 'meta does not match! saving previous results only.', file=sys.stderr)
            # meta now describes the rejected system; restore the one that
            # matches the collected results so the dump below is consistent.
            meta = old_meta
            break
        old_meta = meta
        if verbose:
            print(fl, 'finished')

    if group:
        dump_meta(dump_dir, meta)
        dump_data(dump_dir, **collect_fields(fields, meta, res_list))
        if verbose:
            print('group finished')


if __name__ == "__main__":
    # Running this file directly delegates to the package's SCF command line.
    from deepks.main import scf_cli
    scf_cli()