# run_pretrained_openfold.py
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from datetime import date
import logging
import os
import pickle
import random
import sys
import time

import numpy as np
import torch

from openfold.config import model_config
from openfold.data import (
    data_pipeline,
    feature_pipeline,
    templates,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)

from scripts.utils import add_data_args

def main(args):
    """Run OpenFold inference on every .fasta file in args.fasta_dir.

    For each input: compute (or load precomputed) alignments, build model
    features, run the network, write unrelaxed and Amber-relaxed PDBs to
    args.output_dir, and optionally pickle the raw model outputs.
    """
    # Build the model from its config and load pretrained JAX weights.
    config = model_config(args.model_name)
    model = AlphaFold(config)
    model = model.eval()
    import_jax_weights_(model, args.param_path, version=args.model_name)
    #script_preset_(model)
    model = model.to(args.model_device)

    is_multimer = "multimer" in args.model_name

    # Template search and featurization differ between monomer and multimer
    # presets. The searcher is only needed when alignments are computed here.
    if(is_multimer):
        if(not args.use_precomputed_alignments):
            template_searcher = hmmsearch.Hmmsearch(
                binary_path=args.hmmsearch_binary_path,
                hmmbuild_binary_path=args.hmmbuild_binary_path,
                database_path=args.pdb_seqres_database_path,
            )
        else:
            template_searcher = None

        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    else:
        if(not args.use_precomputed_alignments):
            template_searcher = hhsearch.HHSearch(
                binary_path=args.hhsearch_binary_path,
                databases=[args.pdb70_database_path],
            )
        else:
            template_searcher = None

        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )

    if(not args.use_precomputed_alignments):
        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            uniprot_database_path=args.uniprot_database_path,
            template_searcher=template_searcher,
            # Small BFD is used when the full BFD path is not supplied.
            use_small_bfd=(args.bfd_database_path is None),
            no_cpus=args.cpus,
        )
    else:
        alignment_runner = None

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if(is_multimer):
        # Wrap the monomer pipeline for per-chain multimer processing.
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    # NOTE(review): random_seed is computed but not consumed anywhere in
    # this function — presumably intended for the data pipeline. Confirm.

    feature_processor = feature_pipeline.FeaturePipeline(
        config.data
    )

    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
    if(not args.use_precomputed_alignments):
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    for fasta_path in os.listdir(args.fasta_dir):
        if(not ".fasta" == os.path.splitext(fasta_path)[-1]):
            print(f"Skipping {fasta_path}. Not a .fasta file...")
            continue

        fasta_path = os.path.join(args.fasta_dir, fasta_path)

        # Gather input sequences
        with open(fasta_path, "r") as fp:
            data = fp.read()

        # Parse FASTA: alternating description lines and sequences.
        lines = [
            l.replace('\n', '')
            for prot in data.split('>') for l in prot.strip().split('\n', 1)
        ][1:]
        tags, seqs = lines[::2], lines[1::2]

        if((not is_multimer) and len(tags) != 1):
            print(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # BUG FIX: the loop previously overwrote (tag, seq) with
        # (tags[0], seqs[0]) on every iteration, so alignments were only
        # ever computed for the first chain. Align every chain instead.
        for tag in tags:
            local_alignment_dir = os.path.join(alignment_dir, tag)
            if(args.use_precomputed_alignments is None):
                if not os.path.exists(local_alignment_dir):
                    os.makedirs(local_alignment_dir)

                alignment_runner.run(
                    fasta_path, local_alignment_dir
                )

        # The multimer pipeline walks the per-chain subdirectories itself;
        # the monomer pipeline reads a single chain's directory.
        if(is_multimer):
            local_alignment_dir = alignment_dir
        else:
            local_alignment_dir = os.path.join(
                alignment_dir,
                tags[0],
            )

        # Output files are named after the first tag (unchanged behavior).
        tag = tags[0]

        feature_dict = data_processor.process_fasta(
            fasta_path=fasta_path, alignment_dir=local_alignment_dir
        )

        processed_feature_dict = feature_processor.process_features(
            feature_dict, mode='predict', is_multimer=is_multimer,
        )

        logging.info("Executing model...")
        batch = processed_feature_dict
        with torch.no_grad():
            batch = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in batch.items()
            }

            t = time.perf_counter()

            # Try a fast, unchunked forward pass first; on failure
            # (e.g. CUDA OOM raises RuntimeError), retry with the
            # configured chunk size.
            chunk_size = model.globals.chunk_size
            try:
                model.globals.chunk_size = None
                out = model(batch)
            except RuntimeError:
                model.globals.chunk_size = chunk_size
                out = model(batch)
            finally:
                # BUG FIX: restore the configured chunk size. Previously it
                # stayed None after a successful unchunked pass, disabling
                # the chunked fallback for all subsequent inputs.
                model.globals.chunk_size = chunk_size

            logging.info(f"Inference time: {time.perf_counter() - t}")

        # Toss out the recycling dimensions --- we don't need them anymore
        batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

        plddt = out["plddt"]

        # Broadcast per-residue pLDDT to every atom for PDB B-factors.
        plddt_b_factors = np.repeat(
            plddt[..., None], residue_constants.atom_type_num, axis=-1
        )

        unrelaxed_protein = protein.from_prediction(
            features=batch,
            result=out,
            b_factors=plddt_b_factors
        )

        # Save the unrelaxed PDB.
        unrelaxed_output_path = os.path.join(
            args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as f:
            f.write(protein.to_pdb(unrelaxed_protein))

        amber_relaxer = relax.AmberRelaxation(
            use_gpu=(args.model_device != "cpu"),
            **config.relax,
        )

        # Relax the prediction. OpenMM picks its GPU via
        # CUDA_VISIBLE_DEVICES, so temporarily narrow it to the model's
        # device and restore it afterwards even if relaxation fails.
        t = time.perf_counter()
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
        if("cuda" in args.model_device):
            device_no = args.model_device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = device_no
        try:
            relaxed_pdb_str, _, _ = amber_relaxer.process(
                prot=unrelaxed_protein
            )
        finally:
            # BUG FIX: previously the env var was only restored on success.
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
        logging.info(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(
            args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
        )
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

        if(args.save_outputs):
            output_dict_path = os.path.join(
                args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
            )
            with open(output_dict_path, "wb") as fp:
                pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

259
260
261

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Directory containing one or more .fasta files to predict"""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="""Directory containing template mmCIF files"""
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or 
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from 
             openfold/resources/params"""
    )
    # BUG FIX: argparse's type=bool converts any non-empty string (including
    # "False") to True. A store_true flag gives the intended semantics.
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    # BUG FIX: the seed is numeric; parse it as an int rather than a str.
    parser.add_argument(
        '--data_random_seed', type=int, default=None
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Default parameter path is derived from the model name.
    if(args.param_path is None):
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    if(args.model_device == "cpu" and torch.cuda.is_available()):
        logging.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)