# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import pickle
import random
import time

import numpy as np
import torch

from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    update_timings, relax_protein

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

# Split the installed torch version string ("major.minor[.patch...]") so the
# matmul-precision setting below is only applied where it is available.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Autograd is switched off for the whole process: this script only runs the
# model forward for inference.
torch.set_grad_enabled(False)

# NOTE(review): these project imports follow the torch setup above, as in the
# original file; keep that order unless it is confirmed not to matter.
from openfold.config import model_config
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax

from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args

# Granularity, in residues, to which sequence lengths are padded when
# --trace_model is set (see round_up_seqlen).
TRACING_INTERVAL = 50


def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
    """Compute (or reuse) alignments for every chain of one prediction target.

    For each ``(tag, seq)`` pair a single-record FASTA file is written to
    ``args.output_dir`` and always removed again afterwards. Alignment tools
    are only run when ``args.use_precomputed_alignments`` is None and
    ``<alignment_dir>/<tag>`` does not exist yet; otherwise the existing
    alignments are reused.

    Args:
        tags: Chain identifiers, parallel to ``seqs``.
        seqs: Chain sequences, parallel to ``tags``.
        alignment_dir: Directory holding one alignment subdirectory per tag.
        args: Parsed command-line arguments (tool binaries, database paths,
            ``cpus``, ``use_single_seq_mode``, ``output_dir``, ...).
        is_multimer: Selects the template searcher: ``hmmsearch.Hmmsearch``
            over ``args.pdb_seqres_database_path`` when True, otherwise
            ``hhsearch.HHSearch`` over ``args.pdb70_database_path``.
    """
    for tag, seq in zip(tags, seqs):
        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            # Must be the same per-chain directory that generate_feature_dict
            # reads from. Joining alignment_dir twice would duplicate the
            # prefix whenever alignment_dir is a relative path.
            local_alignment_dir = os.path.join(alignment_dir, tag)

            if (
                args.use_precomputed_alignments is None
                and not os.path.isdir(local_alignment_dir)
            ):
                logger.info(f"Generating alignments for {tag}...")
                os.makedirs(local_alignment_dir, exist_ok=True)

                # Template search needs a searcher; pick the one matching the
                # preset (hmmsearch for multimer, HHSearch otherwise).
                if is_multimer:
                    template_searcher = hmmsearch.Hmmsearch(
                        binary_path=args.hmmsearch_binary_path,
                        hmmbuild_binary_path=args.hmmbuild_binary_path,
                        database_path=args.pdb_seqres_database_path,
                    )
                else:
                    template_searcher = hhsearch.HHSearch(
                        binary_path=args.hhsearch_binary_path,
                        databases=[args.pdb70_database_path],
                    )

                if args.use_single_seq_mode:
                    # In seqemb mode, use AlignmentRunner only to generate templates
                    alignment_runner = data_pipeline.AlignmentRunner(
                        jackhmmer_binary_path=args.jackhmmer_binary_path,
                        uniref90_database_path=args.uniref90_database_path,
                        template_searcher=template_searcher,
                        no_cpus=args.cpus,
                    )
                    embedding_generator = EmbeddingGenerator()
                    embedding_generator.run(tmp_fasta_path, alignment_dir)
                else:
                    alignment_runner = data_pipeline.AlignmentRunner(
                        jackhmmer_binary_path=args.jackhmmer_binary_path,
                        hhblits_binary_path=args.hhblits_binary_path,
                        uniref90_database_path=args.uniref90_database_path,
                        mgnify_database_path=args.mgnify_database_path,
                        bfd_database_path=args.bfd_database_path,
                        uniref30_database_path=args.uniref30_database_path,
                        uniclust30_database_path=args.uniclust30_database_path,
                        uniprot_database_path=args.uniprot_database_path,
                        template_searcher=template_searcher,
                        no_cpus=args.cpus,
                    )

                alignment_runner.run(
                    tmp_fasta_path, local_alignment_dir
                )
            else:
                logger.info(
                    f"Using precomputed alignments for {tag} at {alignment_dir}..."
                )
        finally:
            # Remove temporary FASTA file, even if alignment failed
            os.remove(tmp_fasta_path)


def round_up_seqlen(seqlen):
    """Pad a sequence length up to the next multiple of TRACING_INTERVAL.

    Lengths that are already an exact multiple are returned unchanged.
    """
    n_intervals = math.ceil(seqlen / TRACING_INTERVAL)
    return TRACING_INTERVAL * int(n_intervals)


def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dict for one prediction target.

    All chains are written to a temporary FASTA file in ``args.output_dir``,
    which is always deleted before returning, even if feature generation raises.

    Dispatch:
      * one sequence: ``data_processor.process_fasta`` with that chain's own
        alignment directory ``<alignment_dir>/<tag>`` and
        ``seqemb_mode=args.use_single_seq_mode``;
      * several sequences with a multimer preset: ``process_fasta`` on the
        whole ``alignment_dir``;
      * several sequences otherwise: ``process_multiseq_fasta`` with
        ``alignment_dir`` as the super alignment directory.

    Args:
        tags: Chain identifiers, parallel to ``seqs``.
        seqs: Chain sequences, parallel to ``tags``.
        alignment_dir: Directory holding per-chain alignment subdirectories.
        data_processor: Data pipeline object providing ``process_fasta`` /
            ``process_multiseq_fasta``.
        args: Parsed command-line arguments (``output_dir``,
            ``use_single_seq_mode``, ``config_preset``).

    Returns:
        Whatever the selected ``data_processor`` method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    # One FASTA record per chain; for a single chain this is exactly
    # ">{tag}\n{seq}".
    with open(tmp_fasta_path, "w") as fp:
        fp.write('\n'.join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)))

    try:
        if len(seqs) == 1:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path,
                alignment_dir=os.path.join(alignment_dir, tags[0]),
                seqemb_mode=args.use_single_seq_mode,
            )
        elif "multimer" in args.config_preset:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove temporary FASTA file
        os.remove(tmp_fasta_path)

    return feature_dict


def list_files_with_extensions(dir, extensions):
    """Return names of regular files directly inside ``dir`` with a given suffix.

    Args:
        dir: Directory to scan (not recursive). The name shadows the builtin
            but is kept for keyword-argument compatibility.
        extensions: A suffix string or a tuple of suffixes, as accepted by
            ``str.endswith``.

    Returns:
        Matching file names (not full paths), sorted so the processing order
        does not depend on the filesystem's directory-listing order.
        Subdirectories are skipped even if their names match, since callers
        open each result as a file.
    """
    return sorted(
        name for name in os.listdir(dir)
        if name.endswith(extensions) and os.path.isfile(os.path.join(dir, name))
    )


def _build_data_processor(args, config, is_multimer):
    """Construct the pipeline that turns FASTA files plus alignments into features.

    Args:
        args: Parsed command-line arguments (template mmCIF dir, kalign binary,
            release-date and obsolete-PDB paths, max template date).
        config: Model config; ``config.data.predict.max_templates`` caps the
            number of template hits featurized.
        is_multimer: Whether a multimer preset is in use.

    Returns:
        A ``data_pipeline.DataPipeline`` using a ``HhsearchHitFeaturizer``, or,
        for multimer presets, a ``data_pipeline.DataPipelineMultimer`` wrapping
        a ``DataPipeline`` that uses a ``HmmsearchHitFeaturizer``.
    """
    featurizer_cls = (
        templates.HmmsearchHitFeaturizer if is_multimer
        else templates.HhsearchHitFeaturizer
    )
    template_featurizer = featurizer_cls(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )
    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )
    return data_processor


def _gather_targets(fasta_dir, is_multimer):
    """Read every .fasta/.fa file in ``fasta_dir`` into prediction targets.

    Files with no sequences are skipped. In non-multimer mode, files with more
    than one sequence are also skipped. The tag of a target is its chain tags
    joined with '-'.

    Returns:
        A list of ``((tag, tags), seqs)`` pairs sorted by total residue count,
        shortest first.
    """
    tag_list = []
    seq_list = []
    for fasta_file in list_files_with_extensions(fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if len(tags) == 0:
            logger.warning(f"{fasta_path} contains no sequences. Skipping...")
            continue
        if (not is_multimer) and len(tags) != 1:
            logger.warning(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # NOTE: duplicate chain tags within one file are not rejected here.
        tag = '-'.join(tags)

        tag_list.append((tag, tags))
        seq_list.append(seqs)

    def total_length(target):
        return sum(len(s) for s in target[1])

    return sorted(zip(tag_list, seq_list), key=total_length)


def main(args):
    """Run OpenFold inference on every FASTA file in ``args.fasta_dir``.

    For each model yielded by ``load_models_from_command_line`` and each
    target, this computes or reuses alignments, builds features (cached
    across models), optionally traces the model, runs it, and writes the
    unrelaxed structure (PDB or ModelCIF). Depending on the flags, it also
    writes the relaxed structure and the pickled raw model outputs.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    if args.config_preset.startswith("seq"):
        args.use_single_seq_mode = True

    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset

    data_processor = _build_data_processor(args, config, is_multimer)

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    sorted_targets = _gather_targets(args.fasta_dir, is_multimer)

    # Raw feature dicts are reused across models; with --trace_model the
    # padded length of each cached dict is kept alongside it so tracing
    # decisions stay correct for every model, not only the first.
    feature_dicts = {}
    rounded_seqlens = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)
    for model, output_directory in model_generator:
        # Largest padded length this model has been traced at so far.
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    rounded_seqlens[tag] = round_up_seqlen(n)
                    feature_dict = pad_feature_dict_seq(
                        feature_dict, rounded_seqlens[tag],
                    )

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                rounded_seqlen = rounded_seqlens[tag]
                # Targets are sorted by length, so retrace only when this
                # target needs a larger padded length than the last trace.
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = "_unrelaxed.cif" if args.cif_output else "_unrelaxed.pdb"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")


if __name__ == "__main__":
    # Command-line entry point: parse arguments, fill in the default JAX
    # parameter path if no weights were given, warn about CPU-only runs on a
    # CUDA-capable machine, then run inference.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="Path to directory containing FASTA files, one sequence per file"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--use_single_seq_mode", action="store_true", default=False,
        help="""Use single sequence embeddings instead of MSAs."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    # NOTE(review): args.preset is not read by any code in this file; confirm
    # whether a helper uses it before removing.
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Fall back to the bundled JAX parameters named after the config preset
    # (path is relative to the current working directory).
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)