penalty.py 5.19 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import time
import numpy as np
from pyscf.dft import numint, gen_grid
from pyscf.lib import logger
from deepks.utils import check_list


def select_penalty(name):
    name = name.lower()
    if name == "density":
        return DensityPenalty
    if name == "coulomb":
        return CoulombPenalty
    raise ValueError(f"unknown penalty type: {name}")


class PenaltyMixin(object):
    """Mixin class to add penalty potential in Fock matrix"""

    def __init__(self, penalties=None):
        self.penalties = check_list(penalties)
        for pnt in self.penalties:
            pnt.init_hook(self)

    def get_fock(self, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1, 
                 diis=None, diis_start_cycle=None, 
                 level_shift_factor=None, damp_factor=None):
        """modified get_fock method to apply penalty terms onto vhf"""
        if dm is None:
            dm = self.make_rdm1()
        if h1e is None: 
            h1e = self.get_hcore()
        if vhf is None: 
            vhf = self.get_veff(dm=dm)
        vp = sum(pnt.fock_hook(self, dm=dm, h1e=h1e, vhf=vhf, cycle=cycle) 
                    for pnt in self.penalties)
        vhf = vhf + vp
        return super().get_fock(h1e=h1e, s1e=s1e, vhf=vhf, dm=dm, cycle=cycle, 
                        diis=diis, diis_start_cycle=diis_start_cycle, 
                        level_shift_factor=level_shift_factor, damp_factor=damp_factor)


class AbstructPenalty(object):
    """
    Abstruct class for penalty term in scf hamiltonian.
    To implement a penalty one needs to implement 
    fock_hook and (optional) init_hook methods.
    """
    required_labels = [] # the label would be load and pass to __init__

    def init_hook(self, mf, **envs):
        """
        Method to be called when initialize the scf object.
        Used to initialize the penalty with molecule info.
        """
        pass

    def fock_hook(self, mf, dm=None, h1e=None, vhf=None, cycle=-1, **envs):
        """
        Method to be called before get_fock is called.
        The returned matrix would be added to the vhf matrix
        """
        raise NotImplementedError("fock_hook method is not implemented")


class DummyPenalty(AbstructPenalty):
    def fock_hook(self, mf, dm=None, h1e=None, vhf=None, cycle=-1, **envs):
        return 0


class DensityPenalty(AbstructPenalty):
    r"""
    penalty on the difference w.r.t target density
    E_p = \lambda / 2 * \int dx (\rho(x) - \rho_target(x))^2
    V_p = \lambda * \int dx <ao_i|x> (\rho(x) - \rho_target(x)) <x|ao_j> 
    The target density should be given as density matrix in ao basis
    """
    required_labels = ["dm"]

    def __init__(self, target_dm, strength=1, random=False, start_cycle=0):
        if isinstance(target_dm, str):
            target_dm = np.load(target_dm)
        self.dm_t = target_dm
        self.init_strength = strength
        self.strength = strength * np.random.rand() if random else strength
        self.start_cycle = start_cycle
        # below are values to be initialized later in init_hook
        self.grids = None
        self.ao_value = None

    def init_hook(self, mf, **envs):
        if hasattr(mf, "grid"):
            self.grids = mf.grids
        else:
            self.grids = gen_grid.Grids(mf.mol)

    def fock_hook(self, mf, dm=None, h1e=None, vhf=None, cycle=-1, **envs):
        # cycle > 0 means it is doing scf iteration
        if 0 <= cycle < self.start_cycle:
            return 0
        if self.grids.coords is None:
            self.grids.build()
        if self.ao_value is None:
            self.ao_value = numint.eval_ao(mf.mol, self.grids.coords, deriv=0)
        tic = (time.process_time(), time.perf_counter())
        rho_diff = numint.eval_rho(mf.mol, self.ao_value, dm - self.dm_t)
        v_p = numint.eval_mat(mf.mol, self.ao_value, self.grids.weights, rho_diff, rho_diff)
        # cycle < 0 means it is just checking, we only print here
        if cycle < 0 and mf.verbose >=4:
            diff_norm = np.sum(np.abs(rho_diff)*self.grids.weights)
            logger.info(mf, f"  Density Penalty: |diff| = {diff_norm}")
            logger.timer(mf, "dens_pnt", *tic)
        return self.strength * v_p


class CoulombPenalty(AbstructPenalty):
    r"""
    penalty given by the coulomb energy of density difference

    """
    required_labels = ["dm"]

    def __init__(self, target_dm, strength=1, random=False, start_cycle=0):
        if isinstance(target_dm, str):
            target_dm = np.load(target_dm)
        self.dm_t = target_dm
        self.init_strength = strength
        self.strength = strength * np.random.rand() if random else strength
        self.start_cycle = start_cycle

    def fock_hook(self, mf, dm=None, h1e=None, vhf=None, cycle=-1, **envs):
        # cycle > 0 means it is doing scf iteration
        if 0 <= cycle < self.start_cycle:
            return 0
        tic = (time.process_time(), time.perf_counter())
        ddm = dm - self.dm_t
        v_p = mf.get_j(dm=ddm)
        # cycle < 0 means it is just checking, we only print here
        if cycle < 0 and mf.verbose >=4:
            diff_norm = np.sum(ddm * v_p)
            logger.info(mf, f"  Coulomb Penalty: |diff| = {diff_norm}")
            logger.timer(mf, "coul_pnt", *tic)
        return self.strength * v_p