# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
20

21
22
23
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    update_timings, relax_protein

24
25
26
# Install the default root handler, then create this script's logger and let
# it emit INFO-level messages and above.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
import pickle
29

30
import random
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
32
33
import time
import torch

34
35
36
37
# Split the installed torch version string into its dot-separated components
# and keep the first two as integers.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])

# The matmul precision switch is only applied for torch 1.12 and newer.
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Turn off gradient tracking globally for everything that follows.
torch.set_grad_enabled(False)

46
from openfold.config import model_config
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
47
from openfold.data.tools import hhsearch, hmmsearch
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
48
from openfold.model.model import AlphaFold
49
from openfold.model.torchscript import script_preset_
50
from openfold.data import templates, feature_pipeline, data_pipeline
51
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
52
import openfold.np.relax.relax as relax
53

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
55
56
    tensor_tree_map,
)
57
58
59
60
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
61
from scripts.utils import add_data_args
62

63

64
# Step size used by round_up_seqlen(): sequence lengths are rounded up to the
# next multiple of this value.
TRACING_INTERVAL = 50
65
66


Christina Floristean's avatar
Christina Floristean committed
67
def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
    """Run the alignment tools for every sequence that has no alignments yet.

    For each ``(tag, seq)`` pair the target directory is ``alignment_dir``
    itself in multimer mode and ``alignment_dir/<tag>`` otherwise. Alignments
    are generated only when ``args.use_precomputed_alignments`` is ``None``
    and the target directory does not exist; in every other case the sequence
    is skipped and a message is logged.

    Args:
        tags: Sequence identifiers, parallel to ``seqs``.
        seqs: Sequence strings, parallel to ``tags``.
        alignment_dir: Base directory for alignments.
        args: Parsed command-line namespace; this function reads
            ``output_dir``, ``use_precomputed_alignments``, ``cpus`` and the
            binary/database path options passed to the alignment runner.
        is_multimer: Whether a multimer preset is in use.
    """
    # The scratch FASTA path only depends on the process id, so it is the
    # same for every sequence handled by this call.
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")

    for tag, seq in zip(tags, seqs):
        if is_multimer:
            # NOTE(review): every chain shares alignment_dir here, so once
            # the first chain has created it the remaining chains are
            # skipped -- confirm this matches the layout the multimer data
            # pipeline expects.
            local_alignment_dir = alignment_dir
        else:
            # Must match the directory generate_feature_dict() reads from.
            # The previous nested join produced "<dir>/<dir>/<tag>" whenever
            # alignment_dir was a relative path.
            local_alignment_dir = os.path.join(alignment_dir, tag)

        needs_alignments = (
            args.use_precomputed_alignments is None
            and not os.path.isdir(local_alignment_dir)
        )
        if not needs_alignments:
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")

        os.makedirs(local_alignment_dir)

        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            no_cpus=args.cpus,
        )

        # The scratch file is only written when the runner actually needs
        # it, and is removed even if the runner raises.
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")
        try:
            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        finally:
            os.remove(tmp_fasta_path)


106
107
108
109
def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the nearest multiple of ``TRACING_INTERVAL``."""
    n_intervals = int(math.ceil(seqlen / TRACING_INTERVAL))
    return n_intervals * TRACING_INTERVAL


110
111
112
113
114
115
116
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dict for one prediction target.

    The sequences are written to a scratch FASTA file inside
    ``args.output_dir`` and handed to ``data_processor``:

    * exactly one sequence: ``process_fasta`` with the per-tag alignment
      directory ``alignment_dir/<tag>``;
    * several sequences and ``"multimer"`` in ``args.config_preset``:
      ``process_fasta`` with ``alignment_dir``;
    * otherwise: ``process_multiseq_fasta`` with ``alignment_dir`` as the
      super alignment directory.

    The scratch file is always removed, including when the data processor
    raises.

    Args:
        tags: Sequence identifiers, parallel to ``seqs``.
        seqs: Sequence strings, parallel to ``tags``.
        alignment_dir: Base directory holding the alignments.
        data_processor: Object providing ``process_fasta`` and, for the
            multi-sequence monomer case, ``process_multiseq_fasta``.
        args: Parsed command-line namespace; ``output_dir`` and
            ``config_preset`` are read.

    Returns:
        Whatever the selected ``data_processor`` method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")

    # For a single sequence this is exactly ">tag\nseq"; for several it is
    # the records joined by newlines.
    fasta_text = "\n".join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs))

    try:
        with open(tmp_fasta_path, "w") as fp:
            fp.write(fasta_text)

        if len(seqs) == 1:
            local_alignment_dir = os.path.join(alignment_dir, tags[0])
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        elif "multimer" in args.config_preset:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove temporary FASTA file, even on failure. The existence check
        # avoids masking an error raised while the file was being created.
        if os.path.exists(tmp_fasta_path):
            os.remove(tmp_fasta_path)

    return feature_dict
149

150
151
def list_files_with_extensions(dir, extensions):
    """Return the names in ``dir`` that end with one of ``extensions``.

    Args:
        dir: Directory to list.
        extensions: A single suffix string, or any iterable of suffix
            strings (tuple, list, set, ...).

    Returns:
        Matching entry names (not full paths), sorted so that the result
        does not depend on the arbitrary order of ``os.listdir``.
    """
    # str.endswith only accepts a str or a tuple of str, so normalise other
    # iterables instead of failing with a TypeError.
    if isinstance(extensions, str):
        suffixes = extensions
    else:
        suffixes = tuple(extensions)
    return sorted(f for f in os.listdir(dir) if f.endswith(suffixes))
152

153

154
def main(args):
    """Run inference for every FASTA file in ``args.fasta_dir``.

    Steps, in order:

    1. Create the output directory and build the model config for
       ``args.config_preset``.
    2. Build the template searcher/featurizer and the data pipeline (the
       multimer variants when ``"multimer"`` is in the preset name).
    3. Seed numpy and torch.
    4. Parse the input FASTA files and sort the targets by total sequence
       length.
    5. For each loaded model and each target: compute alignments if needed,
       build (or reuse) features, optionally trace the model, run it, write
       the unrelaxed structure, optionally relax it and optionally pickle
       the raw outputs.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        ValueError: If ``args.trace_model`` is set but the config does not
            enable ``fixed_size`` mode.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(
        args.config_preset,
        long_sequence_inference=args.long_sequence_inference,
    )

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset

    if is_multimer:
        if not args.use_precomputed_alignments:
            template_searcher = hmmsearch.Hmmsearch(
                binary_path=args.hmmsearch_binary_path,
                hmmbuild_binary_path=args.hmmbuild_binary_path,
                database_path=args.pdb_seqres_database_path,
            )
        else:
            template_searcher = None

        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    else:
        if not args.use_precomputed_alignments:
            template_searcher = hhsearch.HHSearch(
                binary_path=args.hhsearch_binary_path,
                databases=[args.pdb70_database_path],
            )
        else:
            template_searcher = None

        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )

    # NOTE(review): this runner is built but never used below --
    # precompute_alignments() constructs its own runner without the template
    # searcher. It is kept so that construction-time behaviour is unchanged;
    # confirm whether it should be wired into the alignment step.
    if not args.use_precomputed_alignments:
        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            uniprot_database_path=args.uniprot_database_path,
            template_searcher=template_searcher,
            use_small_bfd=(args.bfd_database_path is None),
            no_cpus=args.cpus,
        )
    else:
        alignment_runner = None

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir

    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)
    else:
        # The command line may deliver the seed as a string; the arithmetic
        # below needs an integer.
        random_seed = int(random_seed)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    tag_list = []
    seq_list = []
    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(args.fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if (not is_multimer) and len(tags) != 1:
            logger.warning(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # NOTE: uniqueness of the FASTA tags within a file is not enforced.
        tag = '-'.join(tags)

        tag_list.append((tag, tags))
        seq_list.append(seqs)

    # Process the shortest targets first (by total residue count), so that
    # with tracing enabled the traced length only ever grows.
    seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
    sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)

    # Features are cached per target so they are computed once and reused
    # for every model. The padded length is cached alongside them because
    # the tracing decision below needs it on cache hits as well.
    feature_dicts = {}
    rounded_seqlens = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    rounded_seqlen = round_up_seqlen(n)
                    feature_dict = pad_feature_dict_seq(
                        feature_dict, rounded_seqlen,
                    )
                    rounded_seqlens[tag] = rounded_seqlen

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                # Look the length up per target: on the second and later
                # models the features come from the cache, and a leftover
                # value from the previous target would be wrong.
                rounded_seqlen = rounded_seqlens[tag]
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = "_unrelaxed.pdb"
            if args.cif_output:
                unrelaxed_file_suffix = "_unrelaxed.cif"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(
                    config,
                    args.model_device,
                    unrelaxed_protein,
                    output_directory,
                    output_name,
                    args.cif_output,
                )

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")
375
376
377
378


if __name__ == "__main__":
    # Command-line entry point: declare the options, fill in the default
    # parameter path when none was given, then hand over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="Path to directory containing FASTA files, one sequence per file"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    # The seed feeds np.random.seed() and torch.manual_seed(seed + 1) in
    # main(), so it has to be parsed as an integer rather than a string.
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    # The help text previously opened with four quote characters, which left
    # a stray '"' at the start of the displayed help.
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # With neither parameter source given, fall back to the bundled
    # parameter file named after the config preset.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)