import os
from datetime import datetime
from typing import Optional

import e3nn.util.jit
import torch
import torch.nn
from ase.data import chemical_symbols

import sevenn._keys as KEY
from sevenn import __version__
from sevenn.model_build import build_E3_equivariant_model
from sevenn.util import load_checkpoint


def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None):
    """
    This method is messy to avoid changes in pair_e3gnn.cpp, while
    refactoring python part.
    If changes the behavior, and accordingly pair_e3gnn.cpp,
    we have to recompile LAMMPS (which I always want to procrastinate)
    """
    from sevenn.nn.edge_embedding import EdgePreprocess
    from sevenn.nn.force_output import ForceStressOutput

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config

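    # Prepend edge preprocessing and swap the output head for ForceStressOutput
    # so the deployed (serial) model computes forces and stress itself.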
    model.prepand_module('edge_preprocess', EdgePreprocess(True))
    grad_module = ForceStressOutput()
    model.replace_module('force_output', grad_module)
    new_grad_key = grad_module.get_grad_key()
    model.key_grad = new_grad_key
    if hasattr(model, 'eval_type_map'):
        setattr(model, 'eval_type_map', False)

    if modal:
        model.prepare_modal_deploy(modal)
    elif model.modal_map is not None and len(model.modal_map) >= 1:
        raise ValueError(
            f'Modal is not given. It has: {list(model.modal_map.keys())}'
        )

    model.set_is_batch_data(False)
    model.eval()

    model = e3nn.util.jit.script(model)
    model = torch.jit.freeze(model)

    # prepare some extra config needed by the MD (LAMMPS) side
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]
    chem_list = ''
    for Z in type_map.keys():
        chem_list += chemical_symbols[Z] + ' '
    chem_list = chem_list.strip()  # strip() returns a new string; reassign it
    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    if not fname.endswith('.pt'):
        fname += '.pt'
    torch.jit.save(model, fname, _extra_files=md_configs)


# TODO: build model only once
def deploy_parallel(
    checkpoint, fname='deployed_parallel', modal: Optional[str] = None
):
    # Additional layers for ghost atoms (parameters are copied from the original model)
    GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1']

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config
    config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False}
    model_state_dct = model.state_dict()

    model_list = build_E3_equivariant_model(config, parallel=True)
    dct_temp = {}
    copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS}
    for ghost_layer_key in GHOST_LAYERS_KEYS:
        for key, val in model_state_dct.items():
            if not key.startswith(ghost_layer_key):
                continue
            dct_temp.update({f'ghost_{key}': val})
            copy_counter[ghost_layer_key] += 1
    # Ensure reference weights are copied from state dict
    assert all(x > 0 for x in copy_counter.values())

    model_state_dct.update(dct_temp)

    for model_part in model_list:
        missing, _ = model_part.load_state_dict(model_state_dct, strict=False)
        if hasattr(model_part, 'eval_type_map'):
            setattr(model_part, 'eval_type_map', False)
        # Ensure all values are inserted
        assert len(missing) == 0, missing

    if modal:
        model_list[0].prepare_modal_deploy(modal)
    elif model_list[0].modal_map is not None:
        raise ValueError(
            f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}'
        )

    # prepare some extra information for MD
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]

    chem_list = ''
    for Z in type_map.keys():
        chem_list += chemical_symbols[Z] + ' '
    chem_list = chem_list.strip()  # strip() returns a new string; reassign it

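    # comm_size: the largest feature length any convolution layer has to
    # communicate for ghost atoms between domains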
    comm_size = max(
        [
            seg._modules[f'{t}_convolution']._comm_size  # type: ignore
            for t, seg in enumerate(model_list)
        ]
    )

    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update({'comm_size': str(comm_size)})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    os.makedirs(fname)
    for idx, model in enumerate(model_list):
        fname_full = f'{fname}/deployed_parallel_{idx}.pt'
        model.set_is_batch_data(False)
        model.eval()

        model = e3nn.util.jit.script(model)
        model = torch.jit.freeze(model)

        torch.jit.save(model, fname_full, _extra_files=md_configs)
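

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: deploy both a
    # serial and a parallel potential from a training checkpoint. The
    # checkpoint path and output names below are hypothetical; substitute
    # your own.
    deploy('checkpoint_best.pth', fname='deployed_serial.pt')
    deploy_parallel('checkpoint_best.pth', fname='deployed_parallel')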