# convert_alphafold_to_unifold.py
# Convert AlphaFold JAX parameters into a Uni-Fold checkpoint file.
# Committed by zhangqha.
#
# Usage: python convert_alphafold_to_unifold.py <alphafold_params> <output_checkpoint> <model_name>
import torch
import sys

from unifold.config import model_config
from unifold.modules.alphafold import AlphaFold

from scripts.translate_jax_params import (
    import_jax_weights_,
)

def parse_args(argv):
    """Return ``(load_ckpt, save_ckpt, model_name)`` from a ``sys.argv``-style list.

    Args:
        argv: Argument vector whose element 0 is the program name, followed by
            the input AlphaFold params path, the output checkpoint path and
            the model name. Any further elements are ignored.

    Raises:
        SystemExit: With a usage message when fewer than three arguments are
            supplied (instead of a bare ``IndexError``).
    """
    if len(argv) < 4:
        prog = argv[0] if argv else "convert_alphafold_to_unifold.py"
        raise SystemExit(
            f"usage: {prog} <alphafold_params> <output_checkpoint> <model_name>"
        )
    return argv[1], argv[2], argv[3]


def build_checkpoint(state_dict):
    """Wrap a model ``state_dict`` in the checkpoint layout written by this script.

    Every parameter name is prefixed with ``"model."`` and stored under
    ``["ema"]["params"]``; ``["extra_state"]["train_iterator"]["epoch"]`` is
    set to 1. Values are passed through untouched (no copies).

    NOTE(review): presumably this mirrors the layout the Uni-Fold / Uni-Core
    training loader expects (EMA params keyed by the wrapper's ``model.``
    attribute) — confirm against the loader before changing it.

    Args:
        state_dict: Mapping of parameter name to value (e.g. ``Module.state_dict()``).

    Returns:
        A new nested ``dict`` ready to pass to ``torch.save``.
    """
    return {
        "ema": {
            "params": {"model." + name: value for name, value in state_dict.items()},
        },
        "extra_state": {"train_iterator": {"epoch": 1}},
    }


def main(argv=None):
    """Convert AlphaFold JAX weights into a Uni-Fold checkpoint on disk.

    Args:
        argv: Optional ``sys.argv``-style list; defaults to ``sys.argv``.
    """
    load_ckpt, save_ckpt, model_name = parse_args(sys.argv if argv is None else argv)

    config = model_config(model_name)
    model = AlphaFold(config)
    # The same model name selects both the config and the weight-translation
    # variant, so the two stay consistent.
    import_jax_weights_(model, load_ckpt, version=model_name)

    torch.save(build_checkpoint(model.state_dict()), save_ckpt)


# Only convert when run as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()