run_pretrained_openfold.py 16.4 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
20

21
22
23
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    update_timings, relax_protein

24
25
26
# Install a default root handler, then give this script a named logger that
# emits INFO and above.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
import pickle
29

30
import random
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
32
33
import time
import torch

34
35
36
37
38
39
40
41
42
43
44
45
# Parse the leading "<major>.<minor>" components of the installed torch
# version; set_float32_matmul_precision only exists from torch 1.12 on.
torch_versions = torch.__version__.split(".")
torch_major_version, torch_minor_version = (
    int(torch_versions[0]),
    int(torch_versions[1]),
)
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# This script only runs inference, so autograd bookkeeping is never needed.
torch.set_grad_enabled(False)

46
from openfold.config import model_config
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
47
from openfold.data.tools import hhsearch, hmmsearch
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
48
from openfold.model.model import AlphaFold
49
from openfold.model.torchscript import script_preset_
50
from openfold.data import templates, feature_pipeline, data_pipeline
51
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
52
import openfold.np.relax.relax as relax
53

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
55
56
    tensor_tree_map,
)
57
58
59
60
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
61
from scripts.utils import add_data_args
62

63

64
# Residue granularity used with --trace_model: sequence lengths are rounded
# up to a multiple of this value (see round_up_seqlen) and features padded to
# that length, so a model is only re-traced when a target's rounded length
# exceeds the length it was last traced at.
TRACING_INTERVAL = 50
65
66


Christina Floristean's avatar
Christina Floristean committed
67
def _make_alignment_runner(args, is_multimer):
    """Build an AlignmentRunner configured like the rest of this script.

    The template searcher matches the template featurizer used in ``main``:
    hmmsearch against pdb_seqres for multimer presets, HHSearch against pdb70
    otherwise. It is omitted (``None``) when its database path is not set, so
    runs without template databases keep working as before.
    """
    template_searcher = None
    if is_multimer:
        if args.pdb_seqres_database_path is not None:
            template_searcher = hmmsearch.Hmmsearch(
                binary_path=args.hmmsearch_binary_path,
                hmmbuild_binary_path=args.hmmbuild_binary_path,
                database_path=args.pdb_seqres_database_path,
            )
    elif args.pdb70_database_path is not None:
        template_searcher = hhsearch.HHSearch(
            binary_path=args.hhsearch_binary_path,
            databases=[args.pdb70_database_path],
        )

    return data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        uniprot_database_path=args.uniprot_database_path,
        template_searcher=template_searcher,
        use_small_bfd=(args.bfd_database_path is None),
        no_cpus=args.cpus,
    )


def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
    """Generate alignments for every chain that does not have them yet.

    Each chain's results live in their own directory ``<alignment_dir>/<tag>``.
    A chain is skipped when ``args.use_precomputed_alignments`` is set or its
    directory already exists; otherwise the chain is written to a temporary
    FASTA file in ``args.output_dir`` and run through an ``AlignmentRunner``.

    Args:
        tags: FASTA tags (chain names), parallel to ``seqs``.
        seqs: Sequences, one per tag.
        alignment_dir: Root directory holding one sub-directory per chain.
        args: Parsed command-line arguments (binary/database paths, ``cpus``,
            ``output_dir``, ``use_precomputed_alignments``).
        is_multimer: Whether a multimer preset is in use; selects the
            template searcher.
    """
    # Built lazily so nothing is constructed when every chain is cached.
    alignment_runner = None

    for tag, seq in zip(tags, seqs):
        # One directory per chain in both modes. Sharing ``alignment_dir``
        # across multimer chains made every chain after the first look
        # "already computed". Joining ``alignment_dir`` twice nested relative
        # paths (``out/alignments/out/alignments/tag``).
        local_alignment_dir = os.path.join(alignment_dir, tag)

        if (
            args.use_precomputed_alignments is not None
            or os.path.isdir(local_alignment_dir)
        ):
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")
        os.makedirs(local_alignment_dir, exist_ok=True)

        if alignment_runner is None:
            alignment_runner = _make_alignment_runner(args, is_multimer)

        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        finally:
            # Remove temporary FASTA file even if the alignment tools fail
            os.remove(tmp_fasta_path)


106
107
108
109
def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the nearest multiple of ``TRACING_INTERVAL``.

    Exact multiples are returned unchanged; the result is always an ``int``.
    """
    n_intervals = math.ceil(seqlen / TRACING_INTERVAL)
    return int(n_intervals) * TRACING_INTERVAL


110
111
112
113
114
115
116
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Featurize one target from its sequences and precomputed alignments.

    The sequences are written to a temporary FASTA file in ``args.output_dir``,
    which is removed again even if featurization raises.

    Dispatch:
      * multimer presets -> ``data_processor.process_fasta`` with the root
        ``alignment_dir``, for any number of chains. This is checked first so
        single-chain multimer targets are treated like other multimer targets.
      * one sequence -> ``data_processor.process_fasta`` with
        ``<alignment_dir>/<tag>``.
      * several sequences, monomer preset ->
        ``data_processor.process_multiseq_fasta`` with the root directory.

    Args:
        tags: FASTA tags, parallel to ``seqs``.
        seqs: Sequences of the target's chains.
        alignment_dir: Root alignment directory.
        data_processor: Data pipeline object exposing the methods above.
        args: Parsed CLI arguments (``output_dir``, ``config_preset``).

    Returns:
        Whatever feature dict the selected data-pipeline method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    with open(tmp_fasta_path, "w") as fp:
        fp.write("\n".join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)))

    try:
        if "multimer" in args.config_preset:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        elif len(seqs) == 1:
            local_alignment_dir = os.path.join(alignment_dir, tags[0])
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove temporary FASTA file
        os.remove(tmp_fasta_path)

    return feature_dict
149

150
151
def list_files_with_extensions(dir, extensions):
    """Return the names of regular files in ``dir`` ending with ``extensions``.

    ``dir`` keeps its original (builtin-shadowing) name for backward
    compatibility with keyword callers.

    Args:
        dir: Directory to scan (not recursively).
        extensions: A suffix string or tuple of suffixes, as accepted by
            ``str.endswith``.

    Returns:
        Sorted list of matching file names (bare names, not full paths).
        Sorting makes the processing order deterministic, since
        ``os.listdir`` order is arbitrary. Sub-directories are skipped so a
        directory named e.g. ``x.fasta`` is never handed to ``open()``.
    """
    return sorted(
        name
        for name in os.listdir(dir)
        if name.endswith(extensions) and os.path.isfile(os.path.join(dir, name))
    )
152

153

154
def _build_data_processor(config, args, is_multimer):
    """Build the pipeline that turns a FASTA file plus alignments into features.

    Multimer presets use ``HmmsearchHitFeaturizer`` wrapped in a
    ``DataPipelineMultimer``; monomer presets use ``HhsearchHitFeaturizer``
    with a plain ``DataPipeline``.
    """
    featurizer_cls = (
        templates.HmmsearchHitFeaturizer if is_multimer
        else templates.HhsearchHitFeaturizer
    )
    template_featurizer = featurizer_cls(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path,
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )
    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )
    return data_processor


def _collect_targets(fasta_dir, is_multimer):
    """Parse every ``.fasta``/``.fa`` file in ``fasta_dir`` into targets.

    Returns a list of ``((tag, tags), seqs)`` tuples sorted by total residue
    count (shortest first). ``tags``/``seqs`` are the file's FASTA tags and
    sequences, and ``tag`` joins the tags with ``-``. Empty files are
    skipped. Files with several sequences are skipped unless ``is_multimer``
    is set.
    """
    targets = []
    for fasta_file in list_files_with_extensions(fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()
        tags, seqs = parse_fasta(data)

        if not seqs:
            logger.warning(f"{fasta_path} contains no sequences. Skipping...")
            continue

        if not is_multimer and len(tags) != 1:
            logger.warning(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        targets.append((('-'.join(tags), tags), seqs))

    # Shortest first: with --trace_model, a model is then re-traced only as
    # the padded length grows.
    return sorted(targets, key=lambda target: sum(len(s) for s in target[1]))


def _save_prediction(
    args,
    config,
    out,
    processed_feature_dict,
    feature_dict,
    feature_processor,
    output_directory,
    output_name,
):
    """Write one model's prediction for one target to ``output_directory``.

    Always writes ``<output_name>_unrelaxed.pdb``. Also runs relaxation unless
    ``args.skip_relaxation`` is set, and pickles the raw model outputs to
    ``<output_name>_output_dict.pkl`` when ``args.save_outputs`` is set.
    """
    # Toss out the recycling dimensions --- we don't need them anymore
    processed_feature_dict = tensor_tree_map(
        lambda x: np.array(x[..., -1].cpu()),
        processed_feature_dict,
    )
    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    unrelaxed_protein = prep_output(
        out,
        processed_feature_dict,
        feature_dict,
        feature_processor,
        args.config_preset,
        args.multimer_ri_gap,
        args.subtract_plddt,
    )

    unrelaxed_output_path = os.path.join(
        output_directory, f'{output_name}_unrelaxed.pdb'
    )
    with open(unrelaxed_output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))
    logger.info(f"Output written to {unrelaxed_output_path}...")

    if not args.skip_relaxation:
        # Relax the prediction.
        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(
            config, args.model_device, unrelaxed_protein,
            output_directory, output_name,
        )

    if args.save_outputs:
        output_dict_path = os.path.join(
            output_directory, f'{output_name}_output_dict.pkl'
        )
        with open(output_dict_path, "wb") as fp:
            pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Model output written to {output_dict_path}...")


def main(args):
    """Predict structures for every FASTA file in ``args.fasta_dir``.

    For each (model, output directory) pair yielded by
    ``load_models_from_command_line``, every target is processed shortest
    first:
      1. Its alignments are computed if needed.
      2. It is featurized (raw features are cached across models).
      3. The model is run (and traced first if ``--trace_model`` is set).
      4. The prediction is written to the model's output directory.

    Raises:
        ValueError: if ``--trace_model`` is set but the config preset does
            not enable ``fixed_size`` mode.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(
        args.config_preset,
        long_sequence_inference=args.long_sequence_inference,
    )

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset

    data_processor = _build_data_processor(config, args, is_multimer)

    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)
    # The seed may arrive as a string from the command line. The seeding
    # calls need an int (``random_seed + 1`` raises TypeError on a str).
    random_seed = int(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(args.output_dir, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    sorted_targets = _collect_targets(args.fasta_dir, is_multimer)

    # Raw features do not depend on the model: compute once per target.
    feature_dicts = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir,
    )
    for model, output_directory in model_generator:
        # Every newly loaded model has to be traced from scratch.
        cur_tracing_interval = 0

        # Everything below runs once per target. Previously the model
        # run/output code sat outside this loop, so only the last target
        # was ever predicted.
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    feature_dict = pad_feature_dict_seq(
                        feature_dict,
                        round_up_seqlen(feature_dict["aatype"].shape[-2]),
                    )

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )
            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                # Recomputed from this target's (possibly cached) features so
                # it never reflects whichever target was featurized last.
                # Rounding an already-rounded length is a no-op.
                rounded_seqlen = round_up_seqlen(
                    feature_dict["aatype"].shape[-2]
                )
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            _save_prediction(
                args,
                config,
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                output_directory,
                output_name,
            )
369
370
371
372


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Path to directory containing FASTA files, one target per file.
                Files with several sequences are only accepted in multimer
                mode."""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="""Path to directory containing template mmCIF files"""
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        # int, not str: the seed feeds np.random.seed and `seed + 1`.
        "--data_random_seed", type=int, default=None,
        help="""Seed for the NumPy and PyTorch RNGs. Chosen at random if
                omitted."""
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
        help="""Skip relaxation of the predicted structures"""
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    add_data_args(parser)
    args = parser.parse_args()

    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)