model.py 12.5 KB
Newer Older
yuhai's avatar
yuhai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import math
import inspect
import numpy as np
import torch
import torch.nn as nn 
from torch.nn import functional as F
from deepks.utils import load_basis, get_shell_sec
from deepks.utils import load_elem_table, save_elem_table

# Small constant added to scale denominators in this module so that a zero
# scale / zero standard deviation does not cause a division by zero.
SCALE_EPS = 1e-8


def parse_actv_fn(code):
    """Resolve an activation-function spec into a callable.

    A callable is handed back unchanged. A string is matched
    case-insensitively against the supported names: sigmoid, tanh, relu,
    softplus, silu, gelu and mygelu.

    Raises:
        AssertionError: if `code` is neither callable nor exactly a `str`.
        ValueError: if the string is not one of the supported names.
    """
    if callable(code):
        return code
    assert type(code) is str
    key = code.lower()
    if key == 'mygelu':
        return mygelu
    # (namespace, attribute) pairs; the attribute is only looked up for the
    # name that was actually requested.
    lookup = {
        'sigmoid': (torch, 'sigmoid'),
        'tanh': (torch, 'tanh'),
        'relu': (torch, 'relu'),
        'softplus': (F, 'softplus'),
        'silu': (F, 'silu'),
        'gelu': (F, 'gelu'),
    }
    if key not in lookup:
        raise ValueError(f'{code} is not a valid activation function')
    owner, attr = lookup[key]
    return getattr(owner, attr)


def make_embedder(type, shell_sec, **kwargs):
    """Instantiate the embedding module selected by `type`.

    The name is compared case-insensitively: "trace" / "sum" select
    `TraceEmbedding`, "thermal" / "softmax" select `ThermalEmbedding`.
    `shell_sec` and any extra keyword arguments are forwarded to the
    chosen class.

    Raises:
        ValueError: if `type` is not one of the names above.
    """
    kind = type.lower()
    if kind == "trace" or kind == "sum":
        return TraceEmbedding(shell_sec, **kwargs)
    if kind == "thermal" or kind == "softmax":
        return ThermalEmbedding(shell_sec, **kwargs)
    raise ValueError(f'{type} is not a valid embedding type')


def mygelu(x):
    """Tanh-based GELU approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
    coeff = math.sqrt(2 / math.pi)
    inner = x + 0.044715 * torch.pow(x, 3)
    gate = 1 + torch.tanh(coeff * inner)
    return 0.5 * x * gate


def log_args(name):
    """Decorator factory that records a method's call arguments on the instance.

    The decorated method's bound arguments (including defaults, as resolved
    by `inspect.getcallargs`, with the `self` entry removed) are stored as a
    dict in the attribute `name` of the instance *before* the method body
    runs, so the body itself may read or amend that attribute.

    Args:
        name: attribute name under which the argument dict is stored.

    Returns:
        A decorator for instance methods whose first parameter is `self`.
    """
    import functools  # local import keeps this helper self-contained

    def decorator(func):
        # wraps() preserves __name__/__doc__/signature metadata of `func`.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            args_dict = inspect.getcallargs(func, self, *args, **kwargs)
            del args_dict['self']
            setattr(self, name, args_dict)
            # Propagate the result so the decorator is also usable on
            # methods other than __init__.
            return func(self, *args, **kwargs)
        return wrapper
    return decorator


def make_shell_mask(shell_sec):
    """Build a boolean mask of shape [len(shell_sec), max(shell_sec)].

    Row `l` is True in its first `shell_sec[l]` columns and False in the
    remaining (padding) columns.
    """
    n_rows = len(shell_sec)
    n_cols = max(shell_sec)
    col_index = torch.arange(n_cols).unsqueeze(0)            # [1, m]
    row_length = torch.tensor(shell_sec).reshape(n_rows, 1)  # [l, 1]
    # A column is valid for a row when its index is below that row's length.
    return col_index < row_length


def pad_lastdim(sequences, padding_value=0):
    """Stack tensors of differing last-dim length into one padded tensor.

    All tensors are assumed to share dtype and leading dimensions, which are
    taken from `sequences[0]`. The result has shape
    `leading_dims + (len(sequences), max_last_dim)`; entry `i` along the
    second-to-last axis holds `sequences[i]` followed by `padding_value`.
    """
    template = sequences[0]
    lead_shape = template.size()[:-1]
    longest = max(seq.size(-1) for seq in sequences)
    padded = template.new_full(lead_shape + (len(sequences), longest), padding_value)
    for slot, seq in enumerate(sequences):
        # slice assignment copies the data into the freshly allocated tensor
        padded[..., slot, :seq.size(-1)] = seq
    return padded


def pad_masked(tensor, mask, padding_value=0):
    """Scatter the last dim of `tensor` into the True positions of `mask`.

    Equivalent to `pad_lastdim(tensor.split(shell_sec, dim=-1))` when `mask`
    is the shell mask of `shell_sec`. The last dimension of `tensor` must
    equal the number of True entries in `mask`; the result has shape
    `tensor.shape[:-1] + mask.shape` with `padding_value` elsewhere.
    """
    assert tensor.shape[-1] == mask.sum()
    padded_shape = tensor.shape[:-1] + mask.shape
    padded = tensor.new_full(padded_shape, padding_value)
    padded.masked_scatter_(mask, tensor)
    return padded


def unpad_lastdim(padded, length_list):
    """Inverse of `pad_lastdim`: return the un-padded slices as a list.

    Element `i` of the result is `padded[..., i, :length_list[i]]`.
    """
    pieces = []
    for row, length in enumerate(length_list):
        pieces.append(padded[..., row, :length])
    return pieces


def unpad_masked(padded, mask):
    """Gather the True positions of `mask` from `padded` into one last dim.

    Equivalent to `torch.cat(unpad_lastdim(padded, shell_sec), dim=-1)` when
    `mask` is the shell mask of `shell_sec`. The trailing `mask.ndim`
    dimensions of `padded` are replaced by a single dimension of length
    `mask.sum()`.
    """
    lead_shape = padded.shape[:-mask.ndim]
    picked = torch.masked_select(padded, mask)
    return picked.reshape(lead_shape + (mask.sum(),))


def masked_softmax(input, mask, dim=-1):
    """Softmax along `dim` restricted to positions where `mask` is nonzero.

    Masked-out positions get weight 0. The normalizer is clamped from below
    at 1e-10, so a slice with no valid position yields all zeros instead of
    dividing by zero.
    """
    # subtract the per-slice maximum before exponentiating
    peak = input.max(dim=dim, keepdim=True).values
    weights = torch.exp(input - peak)
    weights = weights * mask.to(weights)
    denom = weights.sum(dim=dim, keepdim=True).clamp(min=1e-10)
    return weights / denom


class DenseNet(nn.Module):
    """Stack of `nn.Linear` layers with optional residual connections.

    Args:
        sizes: layer widths; consecutive pairs define the linear layers.
        actv_fn: activation applied after every layer except the last one.
        use_resnet: if True, layers whose input and output widths match add
            their output to their input instead of replacing it.
        with_dt: if True, each layer gets a learnable per-feature factor
            (initialised around 1 with std 0.01) that scales the residual
            branch.
    """

    def __init__(self, sizes, actv_fn=torch.relu, use_resnet=True, with_dt=False):
        super().__init__()
        linears = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            linears.append(nn.Linear(n_in, n_out))
        self.layers = nn.ModuleList(linears)
        self.actv_fn = actv_fn
        self.use_resnet = use_resnet
        self.dts = nn.ParameterList(
            [nn.Parameter(torch.normal(torch.ones(width), std=0.01))
             for width in sizes[1:]]
        ) if with_dt else None

    def forward(self, x):
        """Apply all layers in order and return the final activation."""
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            out = layer(x)
            if idx != last:
                # no activation on the output layer
                out = self.actv_fn(out)
            residual = self.use_resnet and layer.in_features == layer.out_features
            if not residual:
                x = out
                continue
            if self.dts is not None:
                out = out * self.dts[idx]
            x = x + out
        return x


class TraceEmbedding(nn.Module):
    """Embedding that reduces every shell section to the sum of its entries.

    `shell_sec` lists the section lengths along the last input dimension;
    the output has one value per section (`ndesc = len(shell_sec)`).
    """

    def __init__(self, shell_sec):
        super().__init__()
        self.shell_sec = shell_sec
        self.ndesc = len(shell_sec)

    def forward(self, x):
        """Split the last dim by `shell_sec` and sum each piece."""
        sections = x.split(self.shell_sec, dim=-1)
        traces = [section.sum(-1) for section in sections]
        return torch.stack(traces, dim=-1)
    

class ThermalEmbedding(nn.Module):
    """Embedding that takes softmax-weighted sums of the entries in each shell.

    For every shell the entries are normalized with running per-shell
    statistics, multiplied by `-beta` (a learnable parameter with one row per
    shell and one column per output descriptor), and turned into weights by a
    masked softmax over the entries of the shell. Each output descriptor is
    the weighted sum of the raw (un-normalized) shell entries.

    Args:
        shell_sec: lengths of the shell sections along the last input dim.
        embd_sizes: number of descriptors produced per shell. `None` uses
            `shell_sec` itself; an int is repeated for every shell.
        init_beta: each shell's beta row is initialised with
            `linspace(init_beta, -init_beta, n_descriptors)`.
        momentum: optional lower bound on the running-average retention
            factor (see `update_running_stats`).
        max_memory: when `momentum` is None, the running statistics are
            frozen once more than this many batches have been seen.
    """

    def __init__(self, shell_sec, embd_sizes=None, init_beta=5., 
                 momentum=None, max_memory=1000):
        super().__init__()
        self.shell_sec = shell_sec
        # third positional argument of register_buffer is `persistent`:
        # the masks are derived from shell_sec and not stored in state_dict
        self.register_buffer("shell_mask", make_shell_mask(shell_sec), False)# shape: [l, m]
        if embd_sizes is None:
            embd_sizes = shell_sec
        if isinstance(embd_sizes, int):
            embd_sizes = [embd_sizes] * len(shell_sec)
        assert len(embd_sizes) == len(shell_sec)
        self.embd_sizes = embd_sizes
        self.register_buffer("embd_mask", make_shell_mask(embd_sizes), False)
        # total number of output descriptors
        self.ndesc = sum(embd_sizes)
        self.beta = nn.Parameter( # shape: [l, p], padded
            pad_lastdim([torch.linspace(init_beta, -init_beta, ne) 
                            for ne in embd_sizes]))
        self.momentum = momentum
        self.max_memory = max_memory
        # per-shell running statistics used to normalize the softmax input
        self.register_buffer('running_mean', torch.zeros(len(shell_sec)))
        self.register_buffer('running_var', torch.ones(len(shell_sec)))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, x):
        """Return the embedded descriptors; the last dim has length `ndesc`.

        In training mode the running statistics are updated from `x` before
        they are used for normalization.
        """
        x_padded = pad_masked(x, self.shell_mask, 0.) # shape: [n, a, l, m]
        if self.training:
            self.update_running_stats(x_padded)
        # normalize per shell; the mask zeroes the padded entries again
        nx_padded = ((x_padded - self.running_mean.unsqueeze(-1)) 
                    / (self.running_var.sqrt().unsqueeze(-1) + SCALE_EPS)
                    * self.shell_mask.to(x_padded))
        # softmax over the m (entry) axis, one set of weights per descriptor p
        weight = masked_softmax(
            torch.einsum("...lm,lp->...lmp", nx_padded, -self.beta),
            self.shell_mask.unsqueeze(-1), dim=-2)
        # weighted sum of the raw entries of each shell
        desc_padded = torch.einsum("...m,...mp->...p", x_padded, weight)
        # drop the padded descriptor slots and flatten [l, p] into one dim
        return unpad_masked(desc_padded, self.embd_mask)

    def update_running_stats(self, x_padded):
        """Fold the per-shell mean/variance of `x_padded` into the running buffers.

        The batch counter is incremented on every call, including calls that
        return early without updating the statistics.
        """
        self.num_batches_tracked += 1
        if self.momentum is None and self.num_batches_tracked > self.max_memory:
            return # stop updating after `max_memory` batches, so the scaling becomes a fixed parameter
        # 1 - 1/n gives a cumulative average over the batches seen so far
        exp_factor = 1. - 1. / float(self.num_batches_tracked)
        if self.momentum is not None:
            # NOTE(review): `max` keeps the retention factor at least
            # `momentum`; confirm this (rather than `min`) is intended.
            exp_factor = max(exp_factor, self.momentum)
        with torch.no_grad():
            fmask = self.shell_mask.to(x_padded)
            # fraction of valid (non-padded) entries in each shell row
            pad_portion = fmask.mean(-1)
            x_masked = x_padded * fmask # make sure padded part is zero
            # reduce over all leading dims and the entry axis, keep the shell axis
            reduced_dim = (*range(x_masked.ndim-2), -1)
            # dividing by the valid fraction converts the mean over the padded
            # width into the mean over valid entries only
            batch_mean = x_masked.mean(reduced_dim) / pad_portion
            # NOTE(review): the variance is taken with the zero padding
            # included and then rescaled by the same fraction, which is not
            # the exact variance over valid entries — confirm this is intended.
            batch_var = x_masked.var(reduced_dim) / pad_portion
            self.running_mean[:] = exp_factor * self.running_mean + (1-exp_factor) * batch_mean
            self.running_var[:] = exp_factor * self.running_var + (1-exp_factor) * batch_var
        
    def reset_running_stats(self):
        """Restore the running buffers to their initial values (mean 0, var 1, count 0)."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()


class CorrNet(nn.Module):
    """Network mapping per-atom descriptors to one scalar per frame.

    The forward pass normalizes the input, evaluates a linear term on it,
    optionally passes it through an embedding module, feeds the result to a
    `DenseNet`, adds the linear term back, and sums over the atom axis
    (plus `energy_const`). All sub-modules and parameters are float64.

    The constructor arguments are recorded in `self._init_args` by the
    `log_args` decorator so that `save` / `load_dict` can rebuild the model.
    """

    @log_args('_init_args')
    def __init__(self, input_dim, hidden_sizes=(100,100,100), 
                 actv_fn='gelu', use_resnet=True, 
                 embedding=None, proj_basis=None, elem_table=None,
                 input_shift=0, input_scale=1, output_scale=1):
        """Build the network.

        Args:
            input_dim: size of the last (feature) dimension of the input.
            hidden_sizes: widths of the hidden layers of the fitting net.
            actv_fn: activation name or callable (see `parse_actv_fn`).
            use_resnet: forwarded to `DenseNet`.
            embedding: None, an embedding type name, or a dict of keyword
                arguments for `make_embedder` (must contain "type").
            proj_basis: passed to `load_basis`; the loaded result replaces
                the raw argument in `_init_args`.
            elem_table: element table, or a filename passed to
                `load_elem_table`; zipped into `elem_dict` when given.
            input_shift, input_scale, output_scale: initial values of the
                corresponding non-trainable scaling parameters.
        """
        super().__init__()
        actv_fn = parse_actv_fn(actv_fn)
        self.input_dim = input_dim
        # basis info
        self._pbas = load_basis(proj_basis)
        # `_init_args` already exists here: the decorator sets it before
        # this body runs. Store the loaded basis so checkpoints are
        # self-contained.
        self._init_args["proj_basis"] = self._pbas
        self.shell_sec = None
        # elem const
        if isinstance(elem_table, str):
            elem_table = load_elem_table(elem_table)
            self._init_args["elem_table"] = elem_table
        self.elem_table = elem_table
        self.elem_dict = None if elem_table is None else dict(zip(*elem_table))
        # linear fitting
        self.linear = nn.Linear(input_dim, 1).double()
        # embedding net
        ndesc = input_dim
        self.embedder = None
        if embedding is not None:
            if isinstance(embedding, str):
                embedding = {"type": embedding}
            assert isinstance(embedding, dict)
            raw_shell_sec = get_shell_sec(self._pbas)
            # repeat the shell layout as many times as it fits into input_dim
            self.shell_sec = raw_shell_sec * (input_dim // sum(raw_shell_sec))
            assert sum(self.shell_sec) == input_dim
            self.embedder = make_embedder(**embedding, shell_sec=self.shell_sec).double()
            self.linear.requires_grad_(False) # make sure it is symmetric
            ndesc = self.embedder.ndesc
        # fitting net
        layer_sizes = [ndesc, *hidden_sizes, 1]
        self.densenet = DenseNet(layer_sizes, actv_fn, use_resnet).double()
        # scaling part (stored as parameters so they are saved in state_dict,
        # but excluded from gradient updates)
        self.input_shift = nn.Parameter(
            torch.tensor(input_shift, dtype=torch.float64).expand(input_dim).clone(), 
            requires_grad=False)
        self.input_scale = nn.Parameter(
            torch.tensor(input_scale, dtype=torch.float64).expand(input_dim).clone(), 
            requires_grad=False)
        self.output_scale = nn.Parameter(
            torch.tensor(output_scale, dtype=torch.float64), 
            requires_grad=False)
        self.energy_const = nn.Parameter(
            torch.tensor(0, dtype=torch.float64), 
            requires_grad=False)
    
    def forward(self, x):
        """Evaluate the model; the second-to-last axis of the output is summed out."""
        # x: nframes x natom x nfeature
        x = (x - self.input_shift) / (self.input_scale + SCALE_EPS)
        # linear term is evaluated on the normalized, un-embedded input
        l = self.linear(x)
        if self.embedder is not None:
            x = self.embedder(x)
        y = self.densenet(x)
        y = y / self.output_scale + l
        e = y.sum(-2) + self.energy_const
        return e
    
    def get_elem_const(self, elems):
        """Sum the per-element constants of `elems`; 0. when no table is set."""
        if self.elem_dict is None:
            return 0.
        return sum(self.elem_dict[ee] for ee in elems)

    def set_normalization(self, shift=None, scale=None):
        """Overwrite `input_shift` / `input_scale` in place (None leaves a value unchanged)."""
        dtype = self.input_scale.dtype
        if shift is not None:
            self.input_shift.data[:] = torch.tensor(shift, dtype=dtype)
        if scale is not None:
            self.input_scale.data[:] = torch.tensor(scale, dtype=dtype)

    def set_prefitting(self, weight, bias, trainable=False):
        """Load linear-term coefficients and set whether they are trainable."""
        dtype = self.linear.weight.dtype
        # flattened values are broadcast into the existing parameter tensors
        self.linear.weight.data[:] = torch.tensor(weight, dtype=dtype).reshape(-1)
        self.linear.bias.data[:] = torch.tensor(bias, dtype=dtype).reshape(-1)
        self.linear.requires_grad_(trainable)

    def set_energy_const(self, const):
        """Replace the scalar `energy_const` added to every output."""
        dtype = self.energy_const.dtype
        self.energy_const.data = torch.tensor(const, dtype=dtype).reshape([])

    def save_dict(self, **extra_info):
        """Return a checkpoint dict with state_dict, init_args and extra_info."""
        dump_dict = {
            "state_dict": self.state_dict(),
            "init_args": self._init_args,
            "extra_info": extra_info
        }
        return dump_dict

    def save(self, filename, **extra_info):
        """Write the checkpoint produced by `save_dict` to `filename`."""
        torch.save(self.save_dict(**extra_info), filename)

    def compile(self, set_eval=True, **kwargs):
        """Trace `forward` with `torch.jit.trace` and return the traced module.

        The module is switched to eval mode for tracing when `set_eval` is
        True; the previous training mode is restored afterwards, even if
        tracing raises. Extra keyword arguments go to `torch.jit.trace`.
        """
        # NOTE(review): recent PyTorch versions also define
        # `nn.Module.compile`; this method shadows it — confirm intended.
        old_mode = self.training
        if set_eval:
            self.eval()
        try:
            smodel = torch.jit.trace(
                self.forward, 
                torch.empty((2, 2, self.input_dim)),
                **kwargs)
        finally:
            # never leave the model stuck in eval mode after a failed trace
            self.train(old_mode)
        return smodel

    def compile_save(self, filename, **kwargs):
        """Save the traced model, plus `<filename>.elemtab` when an element table is set."""
        torch.jit.save(self.compile(**kwargs), filename)
        if self.elem_table is not None:
            save_elem_table(filename+".elemtab", self.elem_table)
    
    @staticmethod
    def load_dict(checkpoint, strict=False):
        """Rebuild a `CorrNet` from a checkpoint dict made by `save_dict`.

        Checkpoints that store a single "layer_sizes" list are translated to
        `input_dim` / `hidden_sizes`. The caller's `checkpoint` is not
        modified.
        """
        # shallow copy: the legacy-key translation below must not leak into
        # the caller's checkpoint
        init_args = dict(checkpoint["init_args"])
        if "layer_sizes" in init_args:
            layers = init_args.pop("layer_sizes")
            init_args["input_dim"] = layers[0]
            init_args["hidden_sizes"] = layers[1:-1]
        model = CorrNet(**init_args)
        model.load_state_dict(checkpoint['state_dict'], strict=strict)
        return model

    @staticmethod
    def load(filename, strict=False):
        """Load a model file: TorchScript first, plain checkpoint as fallback.

        Returns the object from `torch.jit.load` when that succeeds; if it
        raises `RuntimeError`, the file is read with `torch.load` on CPU and
        passed to `load_dict`.
        """
        try:
            return torch.jit.load(filename)
        except RuntimeError:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(filename, map_location="cpu")
            return CorrNet.load_dict(checkpoint, strict=strict)