convert_openfold_to_unifold.py 2.36 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import sys

# Suffix of the triangle-multiplication linears that get merged: for each
# suffix, ``linear_a_<suffix>`` and ``linear_b_<suffix>`` are fused into
# ``linear_ab_<suffix>``. Order matters: projections ("p") are checked first.
_TRI_MUL_MERGED_SUFFIXES = ("p", "g")


def _rename_key(key):
    """Apply the plain substring renames to a single parameter key."""
    new_key = key.replace("msa_att_col._msa_att", "msa_att_col")
    new_key = new_key.replace("extra_msa_stack.stack", "extra_msa_stack")
    new_key = new_key.replace(".tm.", ".pae.")
    # The "core" path segment is dropped only when it appears as a full
    # dotted segment somewhere in the incoming key.
    if ".core." in key:
        new_key = new_key.replace("core.", "")
    return new_key


def _tri_mul_names(suffix):
    """Return the (a, b, ab) linear-layer names for a merge suffix."""
    return (
        "linear_a_" + suffix,
        "linear_b_" + suffix,
        "linear_ab_" + suffix,
    )


def openfold2unifold(model_states):
    """Rename the keys of an OpenFold-style state dict for Uni-Fold.

    Two kinds of transformation are applied:

    * Plain renames: ``msa_att_col._msa_att`` -> ``msa_att_col``,
      ``extra_msa_stack.stack`` -> ``extra_msa_stack``, ``.tm.`` ->
      ``.pae.``, and removal of ``core.`` from keys containing ``.core.``.
    * Merges: for keys containing ``tri_mul``, each ``linear_a_p`` /
      ``linear_b_p`` pair (and each ``linear_a_g`` / ``linear_b_g`` pair)
      is concatenated along dim 0, ``a`` first, and stored under
      ``linear_ab_p`` / ``linear_ab_g``.

    The plain renames are applied to merged keys as well, so a merged
    entry ends up under the same renamed prefix as its sibling parameters.

    Args:
        model_states: Mapping from parameter name to tensor.

    Returns:
        A new dict with converted keys. Plain entries come first (in input
        order), followed by merged projections, then merged gates. The
        input mapping is not modified.

    Raises:
        KeyError: If only one half of an ``a``/``b`` pair is present.
    """
    new_model_states = {}
    # For each suffix, an insertion-ordered set (dict keys) of merged names
    # that still carry the original, un-renamed prefix so the source
    # tensors can be looked up in ``model_states`` afterwards.
    pending = {suffix: {} for suffix in _TRI_MUL_MERGED_SUFFIXES}

    for key, value in model_states.items():
        merge_suffix = None
        if "tri_mul" in key:
            for suffix in _TRI_MUL_MERGED_SUFFIXES:
                name_a, name_b, _ = _tri_mul_names(suffix)
                if name_a in key or name_b in key:
                    merge_suffix = suffix
                    break
        if merge_suffix is None:
            new_model_states[_rename_key(key)] = value
            continue
        name_a, name_b, name_ab = _tri_mul_names(merge_suffix)
        merged_key = key.replace(name_a, name_ab).replace(name_b, name_ab)
        pending[merge_suffix][merged_key] = None

    for suffix in _TRI_MUL_MERGED_SUFFIXES:
        name_a, name_b, name_ab = _tri_mul_names(suffix)
        for merged_key in pending[suffix]:
            part_a = model_states[merged_key.replace(name_ab, name_a)]
            part_b = model_states[merged_key.replace(name_ab, name_b)]
            # Route the merged name through the same renames as every
            # other key; previously only "core." was stripped here, which
            # left e.g. "extra_msa_stack.stack" untouched on merged keys.
            new_model_states[_rename_key(merged_key)] = torch.cat(
                [part_a, part_b], dim=0
            )

    return new_model_states


def main():
    """Convert the checkpoint at ``argv[1]`` and write it to ``argv[2]``.

    The loaded object is passed to ``openfold2unifold``; every resulting
    key is prefixed with ``"model."`` and stored under ``["ema"]["params"]``.
    ``["extra_state"]["train_iterator"]["epoch"]`` is set to 1
    (NOTE(review): presumably what the consuming trainer expects when it
    restores a checkpoint -- confirm against the loader).
    """
    if len(sys.argv) < 3:
        sys.exit(
            "usage: {} <input_checkpoint> <output_checkpoint>".format(
                sys.argv[0]
            )
        )
    load_ckpt, save_ckpt = sys.argv[1], sys.argv[2]

    # Load onto the CPU so the conversion also works on machines without
    # the device the checkpoint was saved from; only key names and a
    # concatenation are involved, so no accelerator is needed.
    state_dict = torch.load(load_ckpt, map_location="cpu")
    state_dict = openfold2unifold(state_dict)

    save_state_dict = {
        "ema": {
            "params": {"model." + k: v for k, v in state_dict.items()},
        },
        "extra_state": {"train_iterator": {"epoch": 1}},
    }
    torch.save(save_state_dict, save_ckpt)


if __name__ == "__main__":
    main()