run_pretrained_openfold.py 6.64 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
from datetime import date
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
18
import logging
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
20
21

# A hack to get OpenMM and PyTorch to peacefully coexist
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
23
# Sets OPENMM_DEFAULT_PLATFORM to "OpenCL" in this process's environment.
# NOTE(review): this assignment sits ahead of the torch/openfold imports
# below; presumably it has to run before OpenMM is first loaded to take
# effect -- confirm before reordering the imports around it.
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
24
import pickle
25
26
27
import random
import sys

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
from openfold.features import templates, feature_pipeline, data_pipeline
29

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
30
import time
31
32

import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
33
34
import torch

35
from openfold.config import model_config
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
36
from openfold.model.model import AlphaFold
37
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
38
39
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
40
41
    import_jax_weights_,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
42
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
43
44
45
    tensor_tree_map,
)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
46
from scripts.utils import add_data_args
47

48
49
50
51
52
def main(args):
    """Predict a structure for one FASTA file and write the relaxed PDB.

    Steps, in order: build the model named by ``args.model_name`` and load
    the weights at ``args.param_path``; run the alignment tools on
    ``args.fasta_path``; turn the alignments into model features; run the
    model; relax the prediction with Amber; write the result to
    ``<args.output_dir>/relaxed_<args.model_name>.pdb``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``model_name``, ``param_path``, ``model_device``,
            ``fasta_path``, ``output_dir``, ``cpus``, ``data_random_seed``,
            the template options (``template_mmcif_dir``,
            ``max_template_date``, ``max_template_hits``,
            ``kalign_binary_path``, ``obsolete_pdbs_path``), the alignment
            tool binaries (``jackhmmer_binary_path``,
            ``hhblits_binary_path``, ``hhsearch_binary_path``) and the
            database paths (``uniref90_database_path``,
            ``mgnify_database_path``, ``bfd_database_path``,
            ``uniclust30_database_path``, ``small_bfd_database_path``,
            ``pdb70_database_path``).

    Returns:
        None. The alignments directory and the relaxed PDB file are the
        outputs.
    """
    # MODEL SETUP
    config = model_config(args.model_name)
    model = AlphaFold(config.model)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
    model = model.to(args.model_device)

    # FEATURE COLLECTION AND PROCESSING
    num_ensemble = 1

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=args.max_template_hits,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=args.obsolete_pdbs_path,
    )

    # The small BFD is used exactly when no full BFD path was supplied.
    use_small_bfd = args.bfd_database_path is None

    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=use_small_bfd,
        no_cpus=args.cpus,
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
        use_small_bfd=use_small_bfd,
    )

    output_dir_base = args.output_dir

    # NOTE(review): random_seed is derived here but nothing below consumes
    # it, so --data_random_seed currently has no effect on the run. What it
    # is meant to seed is not visible from this function -- confirm and
    # wire it in.
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)

    config.data.predict.num_ensemble = num_ensemble
    feature_processor = feature_pipeline.FeaturePipeline(config)

    # exist_ok avoids the check-then-create race; creating the nested
    # alignments directory also creates output_dir_base when it is missing.
    alignment_dir = os.path.join(output_dir_base, "alignments")
    os.makedirs(alignment_dir, exist_ok=True)

    logging.info("Generating features...")
    alignment_runner.run(args.fasta_path, alignment_dir)

    feature_dict = data_processor.process_fasta(
        fasta_path=args.fasta_path, alignment_dir=alignment_dir
    )

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    # MODEL EXECUTION
    logging.info("Executing model...")
    with torch.no_grad():
        batch = {
            k: torch.as_tensor(v, device=args.model_device)
            for k, v in processed_feature_dict.items()
        }

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments.
        start = time.perf_counter()
        out = model(batch)
        logging.info(f"Inference time: {time.perf_counter() - start}")

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    # Repeat each pLDDT value across atom_type_num entries so it can be
    # passed as the b_factors array of the predicted protein.
    plddt = out["plddt"]
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
    )

    amber_relaxer = relax.AmberRelaxation(**config.relax)

    # Relax the prediction.
    start = time.perf_counter()
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    logging.info(f"Relaxation time: {time.perf_counter() - start}")

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_dir_base, f'relaxed_{args.model_name}.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)


if __name__ == "__main__":
    # Command-line entry point: parse arguments, fill in derived defaults,
    # validate the database selection, then run the prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_path", type=str,
    )
    add_data_args(parser)
    # Not marked required: the current working directory is the default,
    # and a required option would never fall back to it.
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or 
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from 
             openfold/resources/params"""
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs to use to run alignment tools"""
    )
    # NOTE(review): --preset is parsed but not read anywhere in the visible
    # code; confirm whether it is still needed.
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    # Parsed as an integer so a user-supplied seed has the same type as the
    # random.randrange() fallback that main() draws when none is given.
    parser.add_argument(
        '--data_random_seed', type=int, default=None
    )

    args = parser.parse_args()

    # Default parameter file is chosen from the model name; the path is
    # relative, so it resolves against the current working directory.
    if args.param_path is None:
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    if (args.bfd_database_path is None and
            args.small_bfd_database_path is None):
        raise ValueError(
            "At least one of --bfd_database_path or --small_bfd_database_path "
            "must be specified"
        )

    # main() reports progress and timings through logging.info, which the
    # root logger's default WARNING threshold would otherwise drop.
    # basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    main(args)