scale.py 14.4 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType


def _as_univ(
    ss: List[float], type_map: Dict[int, int], default: float
) -> List[float]:
    """
    Expand a type-indexed list into a list indexed by atomic number.

    Args:
        ss: values indexed by type index (the values of ``type_map``)
        type_map: mapping of atomic number -> type index
        default: value used for atomic numbers that are not keys of type_map

    Returns:
        List of length NUM_UNIV_ELEMENT; entry ``z`` is ``ss[type_map[z]]``
        when ``z`` is in type_map, otherwise ``default``.
    """
    assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long'
    univ: List[float] = []
    for atomic_number in range(NUM_UNIV_ELEMENT):
        if atomic_number in type_map:
            univ.append(ss[type_map[atomic_number]])
        else:
            univ.append(default)
    return univ


@compile_mode('script')
class Rescale(nn.Module):
    """
    Scale and shift energy with one global (scale, shift) pair
    (and automatically force and stress)

    forward computes: data[key_output] = data[key_input] * scale + shift
    """

    def __init__(
        self,
        shift: float,
        scale: float,
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        train_shift_scale: bool = False,
        **kwargs,
    ):
        """
        Args:
            shift: additive constant, must be a python float
            scale: multiplicative constant, must be a python float
            data_key_in: key of the tensor read in forward
            data_key_out: key under which the result is stored in forward
            train_shift_scale: if True, shift and scale require gradients
            **kwargs: accepted but not used by this class
        """
        # both values must be plain floats; anything else fails the assert
        assert isinstance(shift, float) and isinstance(scale, float)
        super().__init__()

        shift_tensor = torch.FloatTensor([shift])
        scale_tensor = torch.FloatTensor([scale])
        # registration order (shift, then scale) is kept on purpose
        self.shift = nn.Parameter(shift_tensor, requires_grad=train_shift_scale)
        self.scale = nn.Parameter(scale_tensor, requires_grad=train_shift_scale)

        self.key_input = data_key_in
        self.key_output = data_key_out

    def get_shift(self) -> float:
        """Return the shift as a python float."""
        values = self.shift.detach().cpu().tolist()
        return values[0]

    def get_scale(self) -> float:
        """Return the scale as a python float."""
        values = self.scale.detach().cpu().tolist()
        return values[0]

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        # affine map of the input entry, stored under the output key
        scaled = data[self.key_input]
        data[self.key_output] = scaled * self.scale + self.shift

        return data


@compile_mode('script')
class SpeciesWiseRescale(nn.Module):
    """
    Scale and shift energy (and automatically force and stress) with
    values selected per atom by its type index.

    A list argument is used as given; a float argument is repeated to the
    length of the other (list) argument. Two lists of different lengths, or
    two floats, raise ValueError.
    """

    def __init__(
        self,
        shift: Union[List[float], float],
        scale: Union[List[float], float],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_indices: str = KEY.ATOM_TYPE,
        train_shift_scale: bool = False,
    ):
        """
        Args:
            shift: per-type shifts, or a single float to repeat
            scale: per-type scales, or a single float to repeat
            data_key_in: key of the tensor read in forward
            data_key_out: key under which the result is stored in forward
            data_key_indices: key of the index tensor used to pick per-atom
                shift and scale in forward
            train_shift_scale: if True, shift and scale require gradients

        Raises:
            ValueError: if both are lists of different lengths, or if neither
                of them is a list
        """
        super().__init__()
        assert isinstance(shift, (float, list))
        assert isinstance(scale, (float, list))

        shift_is_list = isinstance(shift, list)
        scale_is_list = isinstance(scale, list)

        if shift_is_list and scale_is_list and len(shift) != len(scale):
            raise ValueError('List length should be same')
        if not (shift_is_list or scale_is_list):
            # the number of species can only be taken from a list
            raise ValueError('Both shift and scale is not a list')

        num_species = len(shift) if shift_is_list else len(scale)

        if isinstance(shift, float):
            shift = [shift] * num_species
        if isinstance(scale, float):
            scale = [scale] * num_species

        shift_tensor = torch.FloatTensor(shift)
        scale_tensor = torch.FloatTensor(scale)
        # registration order (shift, then scale) is kept on purpose
        self.shift = nn.Parameter(shift_tensor, requires_grad=train_shift_scale)
        self.scale = nn.Parameter(scale_tensor, requires_grad=train_shift_scale)

        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_indices = data_key_indices

    def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return shift as a list of floats. If a non-empty type_map is given,
        the list is re-indexed so that the index equals the atomic number;
        0.0 is assigned to atoms not found in type_map.
        """
        values = self.shift.detach().cpu().tolist()
        if not type_map:
            return values
        return _as_univ(values, type_map, 0.0)

    def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]:
        """
        Return scale as a list of floats. If a non-empty type_map is given,
        the list is re-indexed so that the index equals the atomic number;
        1.0 is assigned to atoms not found in type_map.
        """
        values = self.scale.detach().cpu().tolist()
        if not type_map:
            return values
        return _as_univ(values, type_map, 1.0)

    @staticmethod
    def from_mappers(
        shift: Union[float, List[float]],
        scale: Union[float, List[float]],
        type_map: Dict[int, int],
        **kwargs,
    ):
        """
        Build a SpeciesWiseRescale from raw shift / scale values, fitting them
        to the given type_map: (atomic_numbers -> type_indices)

        Each of shift and scale is handled independently:
            - list longer than len(type_map): must have NUM_UNIV_ELEMENT
              entries (indexed by atomic number); entries are picked in the
              order of type indices
            - float: repeated len(type_map) times
            - list of length 1: repeated len(type_map) times
            - anything else: used as it is

        Remaining keyword arguments are forwarded to the constructor.

        Raises:
            ValueError: a list longer than len(type_map) whose length is not
                NUM_UNIV_ELEMENT
        """
        n_atom_types = len(type_map)

        def _fit(values):
            # returns the per-type list for one of shift / scale
            if isinstance(values, list) and len(values) > n_atom_types:
                if len(values) != NUM_UNIV_ELEMENT:
                    raise ValueError('given shift or scale is strange')
                # atomic numbers ordered by their type index
                ordered_z = sorted(type_map, key=type_map.__getitem__)
                return [values[z] for z in ordered_z]
            if isinstance(values, float):
                return [values] * n_atom_types
            if isinstance(values, list) and len(values) == 1:
                return values * n_atom_types
            return values

        fitted = [_fit(raw) for raw in (shift, scale)]
        assert all([len(v) == n_atom_types for v in fitted])
        fitted_shift, fitted_scale = fitted
        return SpeciesWiseRescale(fitted_shift, fitted_scale, **kwargs)

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        # gather shift / scale per entry of the index tensor and apply them
        # as column vectors
        indices = data[self.key_indices]
        scaled = data[self.key_input]
        scale = self.scale[indices].view(-1, 1)
        shift = self.shift[indices].view(-1, 1)
        data[self.key_output] = scaled * scale + shift

        return data


@compile_mode('script')
class ModalWiseRescale(nn.Module):
    """
    Scaling and shifting energy (and automatically force and stress)
    Given shift or scale is either modal-wise and atom-wise or
    not modal-wise but atom-wise. It is always interpreted as atom-wise.

    When the matching use_modal_wise_* flag is True the parameter is indexed
    as [modal_index, atom_type_index] in forward, otherwise as
    [atom_type_index].
    """

    def __init__(
        self,
        shift: List[List[float]],
        scale: List[List[float]],
        data_key_in: str = KEY.SCALED_ATOMIC_ENERGY,
        data_key_out: str = KEY.ATOMIC_ENERGY,
        data_key_modal_indices: str = KEY.MODAL_TYPE,
        data_key_atom_indices: str = KEY.ATOM_TYPE,
        use_modal_wise_shift: bool = False,
        use_modal_wise_scale: bool = False,
        train_shift_scale: bool = False,
    ):
        """
        Args:
            shift: shift values, converted to a float tensor as given
            scale: scale values, converted to a float tensor as given
            data_key_in: key of the tensor read in forward
            data_key_out: key under which the result is stored in forward
            data_key_modal_indices: key of the modal index tensor
            data_key_atom_indices: key of the atom type index tensor
            use_modal_wise_shift: index shift by (modal, atom type) if True,
                by atom type only otherwise
            use_modal_wise_scale: same as above, for scale
            train_shift_scale: if True, shift and scale require gradients
        """
        super().__init__()
        self.shift = nn.Parameter(
            torch.FloatTensor(shift), requires_grad=train_shift_scale
        )
        self.scale = nn.Parameter(
            torch.FloatTensor(scale), requires_grad=train_shift_scale
        )
        self.key_input = data_key_in
        self.key_output = data_key_out
        self.key_atom_indices = data_key_atom_indices
        self.key_modal_indices = data_key_modal_indices
        self.use_modal_wise_shift = use_modal_wise_shift
        self.use_modal_wise_scale = use_modal_wise_scale
        # When True, forward gathers the modal indices through data[KEY.BATCH]
        # (presumably expanding per-graph modal ids to atoms - TODO confirm)
        self._is_batch_data = True

    def get_shift(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], List[List[float]], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is
        type_map is given but not modal wise shift: return univ shift
        both type_map and modal_map is given and modal wise shift: return fully
            resolved modalwise univ shift
        0.0 is assigned for atoms not found in type_map
        """
        shift = self.shift.detach().cpu().tolist()
        if type_map and not self.use_modal_wise_shift:
            shift = _as_univ(shift, type_map, 0.0)
        elif self.use_modal_wise_shift and modal_map and type_map:
            shift = [_as_univ(s, type_map, 0.0) for s in shift]
            shift = {modal: shift[idx] for modal, idx in modal_map.items()}

        return shift

    def get_scale(
        self,
        type_map: Optional[Dict[int, int]] = None,
        modal_map: Optional[Dict[str, int]] = None,
    ) -> Union[List[float], List[List[float]], Dict[str, List[float]]]:
        """
        Nothing is given: return as it is
        type_map is given but not modal wise scale: return univ scale
        both type_map and modal_map is given and modal wise scale: return fully
            resolved modalwise univ scale
        1.0 is assigned for atoms not found in type_map
        """
        scale = self.scale.detach().cpu().tolist()
        # Unknown elements get the neutral scale 1.0 (not 0.0, which is the
        # neutral value of a shift), same as SpeciesWiseRescale.get_scale
        if type_map and not self.use_modal_wise_scale:
            scale = _as_univ(scale, type_map, 1.0)
        elif self.use_modal_wise_scale and modal_map and type_map:
            scale = [_as_univ(s, type_map, 1.0) for s in scale]
            scale = {modal: scale[idx] for modal, idx in modal_map.items()}
        return scale

    @staticmethod
    def from_mappers(
        shift: Union[float, List[float], Dict[str, Any]],
        scale: Union[float, List[float], Dict[str, Any]],
        use_modal_wise_shift: bool,
        use_modal_wise_scale: bool,
        type_map: Dict[int, int],
        modal_map: Dict[str, int],
        **kwargs,
    ):
        """
        Fit dimensions or mapping raw shift scale values to that is valid under
        the given type_map: (atomic_numbers -> type_indices)
        If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT
        , assume it is element-wise list
        otherwise, it is modal-wise list

        Accepted forms for each of shift / scale (use_mw is its modal flag):
            - float: broadcast to every (modal and) atom type
            - list of NUM_UNIV_ELEMENT values: solved through type_map, then
              repeated for every modal if use_mw
            - list of n_modals floats, use_mw: broadcast over atom types
            - list of n_atom_types floats, not use_mw: as it is
            - n_modals lists of n_atom_types values, use_mw: as it is
            - dict keyed by modal name, use_mw: solved through modal_map; each
              value is a float or a list of NUM_UNIV_ELEMENT values

        Remaining keyword arguments are forwarded to the constructor.

        Raises:
            ValueError: the given value matches none of the forms above
        """

        def solve_mapper(arr, mapper):
            # value is attr index and never overlap, key is either 'z' or modal str
            return [arr[key] for key in sorted(mapper, key=lambda k: mapper[k])]

        shift_scale = []
        n_atom_types = len(type_map)
        n_modals = len(modal_map)

        for s, use_mw in (
            (shift, use_modal_wise_shift),
            (scale, use_modal_wise_scale),
        ):
            is_list = isinstance(s, list)
            # guard so that s[0] below is never evaluated on an empty list;
            # an empty list falls through to the ValueError at the end
            non_empty_list = is_list and len(s) > 0

            # solve elemewise, or broadcast
            if isinstance(s, float):
                # given, modal-wise: no, elem-wise: no => broadcast
                if use_mw:
                    res = [[s] * n_atom_types for _ in range(n_modals)]
                else:
                    res = [s] * n_atom_types
            elif is_list and len(s) == NUM_UNIV_ELEMENT:
                # given, modal-wise: no, elem-wise: yes(univ) => solve elem map
                solved = solve_mapper(s, type_map)
                res = [list(solved) for _ in range(n_modals)] if use_mw else solved
            elif (  # given, modal-wise: yes, elem-wise: no => broadcast to elemwise
                non_empty_list
                and isinstance(s[0], float)
                and len(s) == n_modals
                and use_mw
            ):
                res = [[v] * n_atom_types for v in s]
            elif (  # given, modal-wise: no, elem-wise: yes => as it is
                non_empty_list
                and isinstance(s[0], float)
                and len(s) == n_atom_types
                and not use_mw
            ):
                res = s
            elif (  # given, modal-wise: yes, elem-wise: yes => as it is
                non_empty_list
                and isinstance(s[0], list)
                and len(s) == n_modals
                and len(s[0]) == n_atom_types
                and use_mw
            ):
                res = s
            elif isinstance(s, dict) and use_mw:
                # solve modal dict, modal-wise: yes
                s = solve_mapper(s, modal_map)
                res = []
                for v in s:
                    if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT:
                        # elem-wise: yes(univ) => solve elem map
                        v = solve_mapper(v, type_map)
                    elif isinstance(v, float):
                        # elem-wise: no => broadcast to elemwise
                        v = [v] * n_atom_types
                    else:
                        raise ValueError(f'Invalid shift or scale {s}')
                    res.append(v)
            else:
                raise ValueError(f'Invalid shift or scale {s}')

            # sanity check of the resolved shape before building tensors
            if use_mw:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], list)
                    and len(res) == n_modals
                )
                assert all([len(r) == n_atom_types for r in res])  # type: ignore
            else:
                assert (
                    isinstance(res, list)
                    and isinstance(res[0], float)
                    and len(res) == n_atom_types
                )
            shift_scale.append(res)
        shift, scale = shift_scale

        return ModalWiseRescale(
            shift,
            scale,
            use_modal_wise_shift=use_modal_wise_shift,
            use_modal_wise_scale=use_modal_wise_scale,
            **kwargs,
        )

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        # modal indices are gathered through the batch tensor when the module
        # is flagged as working on batched data, and used directly otherwise
        if self._is_batch_data:
            batch = data[KEY.BATCH]
            modal_indices = data[self.key_modal_indices][batch]
        else:
            modal_indices = data[self.key_modal_indices]
        atom_indices = data[self.key_atom_indices]
        # 2D lookup (modal, atom type) when modal-wise, 1D (atom type) otherwise
        shift = (
            self.shift[modal_indices, atom_indices]
            if self.use_modal_wise_shift
            else self.shift[atom_indices]
        )
        scale = (
            self.scale[modal_indices, atom_indices]
            if self.use_modal_wise_scale
            else self.scale[atom_indices]
        )
        data[self.key_output] = data[self.key_input] * scale.view(
            -1, 1
        ) + shift.view(-1, 1)

        return data


def get_resolved_shift_scale(
    module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale],
    type_map: Optional[Dict[int, int]] = None,
    modal_map: Optional[Dict[str, int]] = None,
):
    """
    Collect shift and scale from one of the rescale modules of this file.

    The arguments are forwarded according to the module type:
        Rescale: neither type_map nor modal_map is used
        SpeciesWiseRescale: type_map is passed to get_shift / get_scale
        ModalWiseRescale: both type_map and modal_map are passed

    Return:
        Tuple of (shift, scale) as returned by the module's getters

    Raises:
        ValueError: module is none of the three rescale classes
    """
    if isinstance(module, Rescale):
        shift = module.get_shift()
        scale = module.get_scale()
    elif isinstance(module, SpeciesWiseRescale):
        shift = module.get_shift(type_map)
        scale = module.get_scale(type_map)
    elif isinstance(module, ModalWiseRescale):
        shift = module.get_shift(type_map, modal_map)
        scale = module.get_scale(type_map, modal_map)
    else:
        raise ValueError('Not scale module')
    return (shift, scale)