"include/ck/library/utility/check_err.hpp" did not exist on "37072aac0b382b6f2838176e1708e6884b6f8abd"
addons.py
import time
import torch
import numpy as np
from torch import nn
from pyscf import lib
from pyscf.lib import logger
from pyscf import gto
from pyscf import scf, dft
from deepks.scf.scf import t_make_eig, t_make_grad_eig_dm


def t_ele_grad(bfock, c_vir, c_occ, n_occ):
    # project (batched) AO-basis matrices onto the virtual-occupied MO block,
    # weighting the occupied side by the occupation numbers, then flatten
    g = torch.einsum("pa,qi,...pq->...ai", c_vir, c_occ*n_occ, bfock)
    return g.flatten(-2)
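
# A hedged, self-contained illustration (not part of the original module):
# a toy shape check for t_ele_grad with random tensors.  All sizes below are
# made up for demonstration only.
def _example_t_ele_grad():
    nao, nvir, nocc = 6, 4, 2
    bfock = torch.randn(3, nao, nao)               # batch of Fock-like matrices
    c_vir = torch.randn(nao, nvir)                 # virtual MO coefficients
    c_occ = torch.randn(nao, nocc)                 # occupied MO coefficients
    n_occ = torch.full((nocc,), 2.)                # occupation numbers
    g = t_ele_grad(bfock, c_vir, c_occ, n_occ)
    assert g.shape == (3, nvir * nocc)             # flattened over (vir, occ)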


def make_grad_eig_egrad(dscf, mo_coeff=None, mo_occ=None, gfock=None):
    # project the d(eig)/d(dm) response (gfock) onto the occupied-virtual
    # block to obtain the electronic-gradient form; spin channels are
    # handled by recursion and concatenated along the last axis
    if mo_occ is None:
        mo_occ = dscf.mo_occ
    if mo_coeff is None:
        mo_coeff = dscf.mo_coeff
    if gfock is None:
        dm = dscf.make_rdm1(mo_coeff, mo_occ)
        if dm.ndim >= 3 and isinstance(dscf, scf.uhf.UHF):
            dm = dm.sum(0)
        gfock = t_make_grad_eig_dm(torch.from_numpy(dm), dscf._t_ovlp_shells).numpy()
    if mo_coeff.ndim >= 3 and mo_occ.ndim >= 2:
        return np.concatenate([make_grad_eig_egrad(dscf, mc, mo, gfock)
            for mc, mo in zip(mo_coeff, mo_occ)], axis=-1)
    iocc = mo_occ>0
    t_no = torch.from_numpy(mo_occ[iocc]).to(dscf.device)
    t_co = torch.from_numpy(mo_coeff[:, iocc]).to(dscf.device)
    t_cv = torch.from_numpy(mo_coeff[:, ~iocc]).to(dscf.device)
    t_gfock = torch.from_numpy(gfock).to(dscf.device)
    return t_ele_grad(t_gfock, t_cv, t_co, t_no).cpu().numpy()
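
# Hedged illustration (not part of the original module): when gfock, mo_coeff
# and mo_occ are passed explicitly, make_grad_eig_egrad only needs `device`
# from the SCF object, so a bare stand-in is enough to exercise the
# occupied/virtual projection.  The sizes below are made up.
def _example_make_grad_eig_egrad():
    from types import SimpleNamespace
    rng = np.random.default_rng(0)
    nao, nocc, ndesc = 6, 2, 5
    mo_coeff = np.linalg.qr(rng.standard_normal((nao, nao)))[0]   # fake MOs
    mo_occ = np.zeros(nao)
    mo_occ[:nocc] = 2.                                            # closed shell
    gfock = rng.standard_normal((ndesc, nao, nao))                # fake d(eig)/d(dm)
    fake_scf = SimpleNamespace(device="cpu")                      # hypothetical stand-in
    g = make_grad_eig_egrad(fake_scf, mo_coeff, mo_occ, gfock)
    assert g.shape == (ndesc, (nao - nocc) * nocc)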


def gen_coul_loss(dscf, fock=None, ovlp=None, mo_occ=None):
    nao = dscf.mol.nao
    fock = (fock if fock is not None else dscf.get_fock()).reshape(-1, nao, nao)
    s1e = ovlp if ovlp is not None else dscf.get_ovlp()
    mo_occ = (mo_occ if mo_occ is not None else dscf.mo_occ).reshape(-1, nao)
    def _coul_loss_grad(v, target_dm):
        # return the Coulomb loss and its gradient with respect to the Fock matrix;
        # only supports a single dm, so do not use it directly for UHF
        a_loss = 0.
        a_grad = 0.
        target_dm = target_dm.reshape(fock.shape)
        for tdm, f1e, nocc in zip(target_dm, fock, mo_occ):
            iocc = nocc>0
            moe, moc = dscf._eigh(f1e+v, s1e)
            eo, ev = moe[iocc], moe[~iocc]
            co, cv = moc[:, iocc], moc[:, ~iocc]
            dm = (co * nocc[iocc]) @ co.T
            # calc loss
            ddm = dm - tdm
            dvj = dscf.get_j(dm=ddm)
            loss = 0.5 * np.einsum("ij,ji", ddm, dvj)
            a_loss += loss
            # calc grad with respect to fock matrix
            ie_mn = 1. / (-ev.reshape(-1, 1) + eo)  # 1 / (e_occ - e_vir)
            temp_mn = cv.T @ dvj @ co * nocc[iocc] * ie_mn
            dldv = cv @ temp_mn @ co.T
            dldv = dldv + dldv.T
            a_grad += dldv
        return a_loss, a_grad
    return _coul_loss_grad
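
# Hedged illustration (not part of the original module): a self-contained
# finite-difference check of the response formula used in _coul_loss_grad,
# re-implemented here with plain numpy so no SCF object is needed.  The
# overlap is taken as the identity and an elementwise kernel stands in for
# dscf.get_j; everything below is hypothetical.
def _example_coul_loss_grad_check(nao=8, nocc=3, seed=0):
    rng = np.random.default_rng(seed)

    def sym(a):
        return a + a.T

    f1e = sym(rng.standard_normal((nao, nao)))     # toy Fock matrix
    wgt = sym(rng.standard_normal((nao, nao)))     # toy Coulomb kernel
    tdm = sym(rng.standard_normal((nao, nao)))     # toy target density matrix
    nmo = np.zeros(nao)
    nmo[:nocc] = 2.                                # closed-shell occupations

    def loss_grad(v):
        e, c = np.linalg.eigh(f1e + v)
        eo, ev = e[:nocc], e[nocc:]
        co, cv = c[:, :nocc], c[:, nocc:]
        dm = (co * nmo[:nocc]) @ co.T
        ddm = dm - tdm
        dvj = wgt * ddm                            # stand-in for dscf.get_j(dm=ddm)
        loss = 0.5 * np.einsum("ij,ji", ddm, dvj)
        ie_mn = 1. / (-ev.reshape(-1, 1) + eo)     # 1 / (e_occ - e_vir)
        temp_mn = cv.T @ dvj @ co * nmo[:nocc] * ie_mn
        dldv = cv @ temp_mn @ co.T
        return loss, dldv + dldv.T

    _, g0 = loss_grad(np.zeros((nao, nao)))
    dv = 1e-5 * sym(rng.standard_normal((nao, nao)))
    lp, _ = loss_grad(dv)
    lm, _ = loss_grad(-dv)
    # central difference should match the analytic directional derivative
    assert np.isclose((lp - lm) / 2., np.sum(g0 * dv), rtol=1e-5)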


def make_grad_coul_veig(dscf, target_dm):
    # gradient of the Coulomb loss with respect to veig, chained as
    # dL/dv_c . dv_c/dveig through a double autograd.grad (vjp-of-vjp) pass
    clfn = gen_coul_loss(dscf)
    dm = dscf.make_rdm1()
    if dm.ndim == 3 and isinstance(dscf, scf.uhf.UHF):
        dm = dm.sum(0)
    t_dm = torch.from_numpy(dm).requires_grad_()
    t_eig = t_make_eig(t_dm, dscf._t_ovlp_shells).requires_grad_()
    loss, dldv = clfn(np.zeros_like(dm), target_dm)
    t_veig = torch.zeros_like(t_eig).requires_grad_()
    [t_vc] = torch.autograd.grad(t_eig, t_dm, t_veig, create_graph=True)
    [t_ghead] = torch.autograd.grad(t_vc, t_veig, torch.from_numpy(dldv))
    return t_ghead.detach().cpu().numpy()
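
# Hedged illustration (not part of the original module): the double
# autograd.grad pass used in make_grad_coul_veig above (grad once with a
# dummy vector and create_graph=True, then grad again with respect to that
# vector) turns a reverse-mode pass into a Jacobian-vector product.
# Toy check on y = sin(x):
def _example_double_grad_jvp():
    x = torch.randn(5, dtype=torch.float64, requires_grad=True)
    u = torch.randn(5, dtype=torch.float64)        # direction to push forward
    y = torch.sin(x)
    v = torch.zeros_like(y, requires_grad=True)    # dummy cotangent vector
    [vjp] = torch.autograd.grad(y, x, v, create_graph=True)
    [jvp] = torch.autograd.grad(vjp, v, u)
    assert torch.allclose(jvp, torch.cos(x).detach() * u)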


def calc_optim_veig(dscf, target_dm,
                    target_dec=None, gvx=None,
                    nstep=1, force_factor=1., **optim_args):
    # optimize veig (d ec / d eig) with LBFGS so that the resulting density
    # matrix reproduces target_dm, optionally adding a force-matching term
    # weighted by force_factor
    clfn = gen_coul_loss(dscf, fock=dscf.get_fock(vhf=dscf.get_veff0()))
    dm = dscf.make_rdm1()
    if dm.ndim == 3 and isinstance(dscf, scf.uhf.UHF):
        dm = dm.sum(0)
    t_dm = torch.from_numpy(dm).requires_grad_()
    t_eig = t_make_eig(t_dm, dscf._t_ovlp_shells).requires_grad_()
    t_ec = dscf.net(t_eig.to(dscf.device))
    t_veig = torch.autograd.grad(t_ec, t_eig)[0].requires_grad_()
    t_lde = torch.from_numpy(target_dec) if target_dec is not None else None
    t_gvx = torch.from_numpy(gvx) if gvx is not None else None
    # build closure
    def closure():
        [t_vc] = torch.autograd.grad(
            t_eig, t_dm, t_veig, retain_graph=True, create_graph=True)
        loss, dldv = clfn(t_vc.detach().numpy(), target_dm)
        grad = torch.autograd.grad(
            t_vc, t_veig, torch.from_numpy(dldv), only_inputs=True)[0]
        # build closure for force loss
        if t_lde is not None and t_gvx is not None:
            t_pde = torch.tensordot(t_gvx, t_veig)
            lossde = force_factor * torch.sum((t_pde - t_lde)**2)
            grad = grad + torch.autograd.grad(lossde, t_veig, only_inputs=True)[0]
            loss = loss + lossde
        t_veig.grad = grad
        return loss
    # do the optimization
    optim = torch.optim.LBFGS([t_veig], **optim_args)
    tic = (time.process_time(), time.perf_counter())
    for _ in range(nstep):
        optim.step(closure)
        tic = logger.timer(dscf, 'LBFGS step', *tic)
    logger.note(dscf, f"optimized loss for veig = {closure()}")
    return t_veig.detach().numpy()
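
# Hedged illustration (not part of the original module): calc_optim_veig
# assigns the gradient to t_veig.grad by hand inside the LBFGS closure
# instead of calling backward().  The same pattern on a toy quadratic:
def _example_manual_grad_lbfgs():
    x = torch.zeros(4, dtype=torch.float64, requires_grad=True)
    target = torch.full((4,), 3., dtype=torch.float64)
    optim = torch.optim.LBFGS([x], max_iter=20)
    def closure():
        with torch.no_grad():
            loss = torch.sum((x - target) ** 2)
            x.grad = 2. * (x - target)             # gradient assigned manually
        return loss
    optim.step(closure)
    assert torch.allclose(x.detach(), target, atol=1e-5)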


def gcalc_optim_veig(gdscf, target_dm, target_grad,
                     nstep=1, force_factor=1., **optim_args):
    # wrapper taking a gradient (force) object: subtract the base-method
    # contribution de0 from the target gradient, then optimize veig
    target_dec = target_grad - gdscf.de0
    gvx = gdscf.make_grad_eig_x()
    return calc_optim_veig(
            gdscf.base,
            target_dm=target_dm,
            target_dec=target_dec, gvx=gvx,
            nstep=nstep, force_factor=force_factor, **optim_args)
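
# Hedged illustration (not part of the original module): run the toy
# self-checks defined above when this file is executed directly.
if __name__ == "__main__":
    _example_t_ele_grad()
    _example_make_grad_eig_egrad()
    _example_coul_loss_grad_check()
    _example_double_grad_jvp()
    _example_manual_grad_lbfgs()
    print("all toy checks passed")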