# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import argparse
import logging
import os
import pickle
import random
import sys
import time
from datetime import date

import numpy as np
import torch

from openfold.config import model_config
from openfold.data import (
    data_pipeline,
    feature_pipeline,
    templates,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
    import_jax_weights_,
)
from openfold.utils.tensor_utils import (
    tensor_tree_map,
)

from scripts.utils import add_data_args
def main(args):
    """Run OpenFold inference on every .fasta file in ``args.fasta_dir``.

    For each fasta file: compute (or load precomputed) alignments, build the
    feature dict, run the model, write an unrelaxed PDB, Amber-relax it, and
    optionally pickle the raw model outputs.
    """
    config = model_config(args.model_name)
    model = AlphaFold(config)
    model = model.eval()
    import_jax_weights_(model, args.param_path, version=args.model_name)
    # script_preset_(model)
    model = model.to(args.model_device)

    is_multimer = "multimer" in args.model_name

    # Template search + featurization differ between the monomer and multimer
    # presets (hhsearch over pdb70 vs. hmmsearch over pdb_seqres).
    if(is_multimer):
        if(not args.use_precomputed_alignments):
            template_searcher = hmmsearch.Hmmsearch(
                binary_path=args.hmmsearch_binary_path,
                hmmbuild_binary_path=args.hmmbuild_binary_path,
                database_path=args.pdb_seqres_database_path,
            )
        else:
            template_searcher = None

        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path,
        )
    else:
        if(not args.use_precomputed_alignments):
            template_searcher = hhsearch.HHSearch(
                binary_path=args.hhsearch_binary_path,
                databases=[args.pdb70_database_path],
            )
        else:
            template_searcher = None

        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path,
        )

    if(not args.use_precomputed_alignments):
        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            uniprot_database_path=args.uniprot_database_path,
            template_searcher=template_searcher,
            use_small_bfd=(args.bfd_database_path is None),
            no_cpus=args.cpus,
        )
    else:
        alignment_runner = None

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if(is_multimer):
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)

    feature_processor = feature_pipeline.FeaturePipeline(
        config.data
    )

    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
    if(not args.use_precomputed_alignments):
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    for fasta_file in os.listdir(args.fasta_dir):
        if(not ".fasta" == os.path.splitext(fasta_file)[-1]):
            print(f"Skipping {fasta_file}. Not a .fasta file...")
            continue

        fasta_path = os.path.join(args.fasta_dir, fasta_file)

        # Gather input sequences
        with open(fasta_path, "r") as fp:
            data = fp.read()

        lines = [
            l.replace('\n', '')
            for prot in data.split('>') for l in prot.strip().split('\n', 1)
        ][1:]
        tags, seqs = lines[::2], lines[1::2]

        if((not is_multimer) and len(tags) != 1):
            print(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # Compute alignments for every chain in the file. (Bug fix: the loop
        # previously clobbered its loop variables with tags[0]/seqs[0], so
        # only the first chain's alignment directory was ever populated.)
        for tag in tags:
            local_alignment_dir = os.path.join(alignment_dir, tag)
            if(args.use_precomputed_alignments is None):
                if not os.path.exists(local_alignment_dir):
                    os.makedirs(local_alignment_dir)

                alignment_runner.run(
                    fasta_path, local_alignment_dir
                )

        if(is_multimer):
            local_alignment_dir = alignment_dir
        else:
            local_alignment_dir = os.path.join(
                alignment_dir,
                tags[0],
            )

        feature_dict = data_processor.process_fasta(
            fasta_path=fasta_path, alignment_dir=local_alignment_dir
        )

        processed_feature_dict = feature_processor.process_features(
            feature_dict, mode='predict', is_multimer=is_multimer,
        )

        logging.info("Executing model...")
        batch = processed_feature_dict
        with torch.no_grad():
            batch = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in batch.items()
            }

            t = time.perf_counter()

            # Try the faster unchunked path first; on a RuntimeError
            # (presumably OOM --- TODO confirm) retry with the configured
            # chunk size. (Bug fix: the chunk size is now restored in a
            # finally block --- previously a successful unchunked run left
            # chunking disabled for every subsequent fasta file.)
            chunk_size = model.globals.chunk_size
            try:
                model.globals.chunk_size = None
                out = model(batch)
            except RuntimeError:
                model.globals.chunk_size = chunk_size
                out = model(batch)
            finally:
                model.globals.chunk_size = chunk_size
            logging.info(f"Inference time: {time.perf_counter() - t}")

        # Toss out the recycling dimensions --- we don't need them anymore
        batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

        plddt = out["plddt"]
        # (Previously computed but never used; surface it in the log instead.)
        logging.info(f"Mean pLDDT: {np.mean(plddt)}")

        plddt_b_factors = np.repeat(
            plddt[..., None], residue_constants.atom_type_num, axis=-1
        )

        unrelaxed_protein = protein.from_prediction(
            features=batch,
            result=out,
            b_factors=plddt_b_factors,
            remove_leading_feature_dimension=not is_multimer,
        )

        # Save the unrelaxed PDB. Output files are named after the first tag,
        # matching the original behavior.
        unrelaxed_output_path = os.path.join(
            args.output_dir, f'{tags[0]}_{args.model_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as f:
            f.write(protein.to_pdb(unrelaxed_protein))

        # (Leftover keyboard-mash debug print removed.)
        print(unrelaxed_output_path)

        amber_relaxer = relax.AmberRelaxation(
            use_gpu=(args.model_device != "cpu"),
            **config.relax,
        )

        # Relax the prediction.
        t = time.perf_counter()
        # Pin the relaxer to the model's GPU via CUDA_VISIBLE_DEVICES, then
        # restore the previous value. (Bug fix: restoration now happens in a
        # finally block so a failed relaxation no longer leaks the override.)
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
        if("cuda" in args.model_device):
            device_no = args.model_device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = device_no
        try:
            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        finally:
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
        logging.info(f"Relaxation time: {time.perf_counter() - t}")

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(
            args.output_dir, f'{tags[0]}_{args.model_name}_relaxed.pdb'
        )
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

        if(args.save_outputs):
            output_dict_path = os.path.join(
                args.output_dir, f'{tags[0]}_{args.model_name}_output_dict.pkl'
            )
            with open(output_dict_path, "wb") as fp:
                pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

263
264
265

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Directory containing the .fasta files to run prediction on"""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="""Directory containing mmCIF files used for template search"""
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or 
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from 
             openfold/resources/params"""
    )
    # Bug fix: argparse's type=bool treats ANY non-empty string (including
    # "False") as True; a store_true flag gives the intended on/off semantics.
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    # Bug fix: the seed was parsed as str, but main() falls back to an int
    # from random.randrange --- parse as int so the type is consistent.
    parser.add_argument(
        '--data_random_seed', type=int, default=None,
        help="""Seed for the data pipeline. Chosen randomly if unspecified."""
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Default parameter location follows the model name.
    if(args.param_path is None):
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    if(args.model_device == "cpu" and torch.cuda.is_available()):
        logging.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)