util.py 11.2 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import os
import os.path as osp
import pathlib
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn
from e3nn.o3 import FullTensorProduct, Irreps
from tqdm import tqdm

import sevenn._const as _const
import sevenn._keys as KEY


def to_atom_graph_list(atom_graph_batch):
    """
    Split a torch_geometric batched data object into a list of graphs.

    PyG's own to_data_list() is not enough here because it does not
    carry over tensors that were inferred after batching.
    """
    has_stress = KEY.PRED_STRESS in atom_graph_batch

    data_list = atom_graph_batch.to_data_list()
    split_sizes = atom_graph_batch[KEY.NUM_ATOMS].tolist()

    atomic_energies = torch.split(
        atom_graph_batch[KEY.ATOMIC_ENERGY], split_sizes
    )
    total_energies = torch.unbind(atom_graph_batch[KEY.PRED_TOTAL_ENERGY])
    forces = torch.split(atom_graph_batch[KEY.PRED_FORCE], split_sizes)

    stresses = None
    if has_stress:
        stresses = torch.unbind(atom_graph_batch[KEY.PRED_STRESS])

    for idx, graph in enumerate(data_list):
        graph[KEY.ATOMIC_ENERGY] = atomic_energies[idx]
        graph[KEY.PRED_TOTAL_ENERGY] = total_energies[idx]
        graph[KEY.PRED_FORCE] = forces[idx]
        # unsqueeze so the shape matches the KEY.STRESS (ref) format
        if stresses is not None:
            graph[KEY.PRED_STRESS] = torch.unsqueeze(stresses[idx], 0)
    return data_list


def error_recorder_from_loss_functions(loss_functions):
    """Build an ErrorRecorder whose metrics mirror the given loss functions.

    Args:
        loss_functions: iterable of (loss_function, weight) pairs.
    Returns:
        ErrorRecorder holding an RMSE metric for each MSELoss criterion
        and an MAE metric for each L1Loss criterion; other criteria are
        silently skipped.
    """
    from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type
    from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss

    metrics = []
    for loss_function, _ in loss_functions:
        if type(loss_function) is PerAtomEnergyLoss:
            base = get_err_type('Energy')
        elif type(loss_function) is ForceLoss:
            base = get_err_type('Force')
        elif type(loss_function) is StressLoss:
            base = get_err_type('Stress')
        else:
            base = {}
        # Fix: copy before mutating — get_err_type appears to return a
        # shared template dict (get_error_recorder copies it for the
        # same reason), and writing into it here would corrupt it.
        base = base.copy()
        base['name'] = loss_function.name
        base['ref_key'] = loss_function.ref_key
        base['pred_key'] = loss_function.pred_key
        criterion = loss_function.criterion
        if type(criterion) is torch.nn.MSELoss:
            base['name'] = base['name'] + '_RMSE'
            metrics.append(RMSError(**base))
        elif type(criterion) is torch.nn.L1Loss:
            metrics.append(MAError(**base))
    return ErrorRecorder(metrics)


def onehot_to_chem(
    one_hot_indices: List[int], type_map: Dict[int, int]
) -> List[str]:
    """Map one-hot (type) indices back to chemical symbols via inverted type_map."""
    from ase.data import chemical_symbols

    inverse_map = {idx: z for z, idx in type_map.items()}
    return [chemical_symbols[inverse_map[i]] for i in one_hot_indices]


def model_from_checkpoint(
    checkpoint: str,
) -> Tuple[torch.nn.Module, Dict]:
    """Load a checkpoint (path or pretrained name) and return (model, config)."""
    loaded = load_checkpoint(checkpoint)
    return loaded.build_model(), loaded.config


def model_from_checkpoint_with_backend(
    checkpoint: str,
    backend: str = 'e3nn',
) -> Tuple[torch.nn.Module, Dict]:
    """Load a checkpoint and build its model with the given backend."""
    loaded = load_checkpoint(checkpoint)
    return loaded.build_model(backend), loaded.config


def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC):
    """Convert unlabeled atoms into an AtomGraphData ready for inference.

    Enables gradients on grad_key (edge vectors by default) so forces
    can be obtained by autograd.
    """
    from .atom_graph_data import AtomGraphData
    from .train.dataload import unlabeled_atoms_to_graph

    graph = AtomGraphData.from_numpy_dict(
        unlabeled_atoms_to_graph(atoms, cutoff)
    )
    graph[grad_key].requires_grad_(True)
    # single structure: batch is a zero-length placeholder tensor
    graph[KEY.BATCH] = torch.zeros([0])
    return graph


def chemical_species_preprocess(input_chem: List[str], universal: bool = False):
    """Build config entries describing the model's chemical species.

    If universal, every element known to ASE is included and the type
    map is the identity over atomic numbers; otherwise the given symbols
    are deduplicated, stripped, and sorted.
    """
    from ase.data import atomic_numbers, chemical_symbols

    from .nn.node_embedding import get_type_mapper_from_specie

    config = {}
    if universal:
        n_univ = len(chemical_symbols)
        config[KEY.CHEMICAL_SPECIES] = chemical_symbols
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(n_univ))
        config[KEY.NUM_SPECIES] = n_univ
        config[KEY.TYPE_MAP] = {z: z for z in range(n_univ)}
    else:
        # dedupe first, then strip + sort (matches original order of ops)
        species = sorted(sym.strip() for sym in set(input_chem))
        config[KEY.CHEMICAL_SPECIES] = species
        config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [
            atomic_numbers[sym] for sym in species
        ]
        config[KEY.NUM_SPECIES] = len(species)
        config[KEY.TYPE_MAP] = get_type_mapper_from_specie(species)
    return config


def dtype_correct(
    v: Union[np.ndarray, torch.Tensor, int, float],
    float_dtype: torch.dtype = torch.float32,
    int_dtype: torch.dtype = torch.int64,
):
    """Cast v to a torch tensor of the requested float/int dtype.

    numpy arrays and torch tensors are converted according to their
    numeric kind (floating -> float_dtype, everything else assumed
    integral -> int_dtype for tensors). Python ints/floats become
    0-dim tensors. Non-numeric values are returned unchanged.

    Fix: a numpy array that was neither floating nor integer (e.g.
    bool or object dtype) previously fell through both branches and
    returned None implicitly; it is now returned unchanged, matching
    the non-numeric scalar branch.
    """
    if isinstance(v, np.ndarray):
        if np.issubdtype(v.dtype, np.floating):
            return torch.from_numpy(v).to(float_dtype)
        if np.issubdtype(v.dtype, np.integer):
            return torch.from_numpy(v).to(int_dtype)
        return v  # bool/object/str arrays: pass through unchanged
    if isinstance(v, torch.Tensor):
        if v.dtype.is_floating_point:
            return v.to(float_dtype)  # convert to specified float dtype
        # assuming non-floating point tensors are integers
        return v.to(int_dtype)  # convert to specified int dtype
    # scalar values (note: bool is a subclass of int, so True/False
    # become int tensors, as before)
    if isinstance(v, int):
        return torch.tensor(v, dtype=int_dtype)
    if isinstance(v, float):
        return torch.tensor(v, dtype=float_dtype)
    return v  # not numeric


def infer_irreps_out(
    irreps_x: Irreps,
    irreps_operand: Irreps,
    drop_l: Union[bool, int] = False,
    parity_mode: str = 'full',
    fix_multiplicity: Union[bool, int] = False,
):
    """Infer output irreps of a full tensor product, with optional filters.

    drop_l: if not False, discard irreps with l > drop_l.
    parity_mode: 'full' keeps all; 'even' drops odd parity; 'sph'
        keeps only irreps with the spherical-harmonic parity (-1)**l.
    fix_multiplicity: if truthy, override every multiplicity with it.
    """
    assert parity_mode in ['full', 'even', 'sph']
    # elements iterate as (mul, (l, p))
    product = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify()
    kept = []
    for mul, (l, p) in product:  # noqa
        if drop_l is not False and l > drop_l:
            continue
        if parity_mode == 'even' and p == -1:
            continue
        elif parity_mode == 'sph' and p != (-1) ** l:
            continue
        out_mul = fix_multiplicity if fix_multiplicity else mul
        kept.append((out_mul, (l, p)))
    return Irreps(kept)


def download_checkpoint(path: str, url: str):
    """Download url to path, streaming with a tqdm progress bar.

    Writes to '{path}.partial' first and renames on success, so an
    interrupted transfer never leaves a truncated file at path. On
    failure the partial file is removed and the exception re-raised;
    PermissionError propagates untouched (caller may retry elsewhere).
    """
    fname = osp.basename(path)
    partial = path + '.partial'
    try:
        # raises permission error if fails
        os.makedirs(osp.dirname(path), exist_ok=True)
        resp = requests.get(url, stream=True, timeout=30)
        resp.raise_for_status()  # Raise exception for bad status codes

        nbytes_total = int(resp.headers.get('content-length', 0))
        chunk_size = 1024  # 1 KB chunks

        pbar = tqdm(
            total=nbytes_total,
            unit='B',
            unit_scale=True,
            desc=f'Downloading {fname}',
        )
        with open(partial, 'wb') as fp:
            for chunk in resp.iter_content(chunk_size):
                pbar.update(len(chunk))
                fp.write(chunk)
        pbar.close()

        shutil.move(partial, path)
        print(f'Checkpoint downloaded: {path}')
        return path
    except PermissionError:
        raise
    except Exception as e:
        # Clean up partial downloads on failure
        # May not work as errors handled internally by tqdm etc.
        print(f'Download failed: {str(e)}')
        if os.path.exists(partial):
            print(f'Cleaning up partial download: {partial}')
            os.remove(partial)
        raise


def pretrained_name_to_path(name: str) -> str:
    """Resolve a pretrained model alias to a checkpoint path.

    Looks for the checkpoint in the package location and then the user
    cache; if absent, downloads it (falling back to the cache directory
    when the package location is not writable).
    """
    name = name.lower()
    heads = ['sevennet', '7net']

    # alias suffix -> checkpoint constant  # TODO: regex
    suffix_to_checkpoint = {
        '-0_11july2024': _const.SEVENNET_0_11Jul2024,
        '-0_11jul2024': _const.SEVENNET_0_11Jul2024,
        '-0': _const.SEVENNET_0_11Jul2024,
        '-0_22may2024': _const.SEVENNET_0_22May2024,
        '-l3i5': _const.SEVENNET_l3i5,
        '-mf-0': _const.SEVENNET_MF_0,
        '-mf-ompa': _const.SEVENNET_MF_ompa,
        '-omat': _const.SEVENNET_omat,
    }
    checkpoint_path = None
    for head in heads:
        for suffix, cp_path in suffix_to_checkpoint.items():
            if name == head + suffix:
                checkpoint_path = cp_path
                break
        if checkpoint_path is not None:
            break
    if checkpoint_path is None:
        raise ValueError('Not a valid pretrained model name')
    url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path)

    candidates = [
        checkpoint_path,
        checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')),
    ]
    for candidate in candidates:
        if osp.exists(candidate):
            return candidate

    # File not found check url and try download
    if url is None:
        raise FileNotFoundError(checkpoint_path)

    try:
        return download_checkpoint(candidates[0], url)  # 7net package path
    except PermissionError:
        return download_checkpoint(candidates[1], url)  # ~/.cache


def load_checkpoint(checkpoint: Union[pathlib.Path, str]):
    """Load a SevenNetCheckpoint from a file path or a pretrained name.

    Args:
        checkpoint: path to an existing checkpoint file, or a pretrained
            model alias such as '7net-0'.
    Raises:
        ValueError: if checkpoint is neither an existing file nor a
            known pretrained model name.
    """
    from sevenn.checkpoint import SevenNetCheckpoint

    # Fix: this was a list holding ONE comma-joined string, so the error
    # message rendered as "['7net-0, 7net-l3i5, ...']"; also fixed the
    # "is not exists" grammar in the message.
    suggests = ['7net-0', '7net-l3i5', '7net-mf-ompa', '7net-omat']
    if osp.isfile(checkpoint):
        checkpoint_path = checkpoint
    else:
        try:
            checkpoint_path = pretrained_name_to_path(str(checkpoint))
        except ValueError:
            raise ValueError(
                f'Given {checkpoint} does not exist and is not a '
                f'pre-trained name.\n'
                f'Valid pretrained model names: {", ".join(suggests)}'
            )
    return SevenNetCheckpoint(checkpoint_path)


def unique_filepath(filepath: str) -> str:
    """Return filepath if no file exists there, else a numbered variant.

    Appends the smallest counter (starting at 0) before the extension
    that yields a path not present on disk.
    """
    if not os.path.isfile(filepath):
        return filepath
    parent, base = os.path.split(filepath)
    stem, ext = os.path.splitext(base)
    counter = 0
    while True:
        candidate = os.path.join(parent, f'{stem}{counter}{ext}')
        if not os.path.exists(candidate):
            return candidate
        counter += 1


def get_error_recorder(
    recorder_tuples: Tuple[Tuple[str, str], ...] = (
        ('Energy', 'RMSE'),
        ('Force', 'RMSE'),
        ('Stress', 'RMSE'),
        ('Energy', 'MAE'),
        ('Force', 'MAE'),
        ('Stress', 'MAE'),
    ),
):
    """Build an ErrorRecorder for the given (quantity, metric) pairs.

    Args:
        recorder_tuples: pairs of (error type, metric name), e.g.
            ('Energy', 'RMSE'). Fix: the default is now an immutable
            tuple — a mutable list default is shared across calls
            (flake8 B006). Callers passing lists are unaffected.
    """
    # TODO add criterion argument and loss recorder selections
    import sevenn.error_recorder as error_recorder

    err_metrics = []
    for err_type, metric_name in recorder_tuples:
        # copy so the template dict returned by get_err_type is not mutated
        metric_kwargs = error_recorder.get_err_type(err_type).copy()
        metric_kwargs['name'] += f'_{metric_name}'
        metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name]
        err_metrics.append(metric_cls(**metric_kwargs))
    return error_recorder.ErrorRecorder(err_metrics)