run_pretrained_openfold.py 17.7 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
3
#
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
4
5
6
7
8
9
10
11
12
13
14
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
20
21
22
import pickle
import random
import time
23

24
25
26
# Configure a module-level logger for progress/status messages.
# NOTE(review): uses __file__ rather than the conventional __name__ as the
# logger name — presumably intentional for a standalone script; confirm.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
import torch
29
30
31
# Parse "major.minor" out of the torch version string.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])

# torch.set_float32_matmul_precision exists from torch 1.12 onward.
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: gradients are never needed.
torch.set_grad_enabled(False)

41
from openfold.config import model_config
42
from openfold.data import templates, feature_pipeline, data_pipeline
43
44
45
46
47
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
                                         prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
48
49
50
51
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
52

53
from scripts.precompute_embeddings import EmbeddingGenerator
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from scripts.utils import add_data_args
55

56

57
TRACING_INTERVAL = 50
58
59


60
def precompute_alignments(tags, seqs, alignment_dir, args):
    """Generate alignments for each (tag, seq) pair unless precomputed.

    Each sequence is written to a per-process scratch FASTA file and fed to
    the alignment tools, whose outputs land in ``alignment_dir/<tag>``. When
    ``args.use_precomputed_alignments`` is set, only a log message is emitted.

    Args:
        tags: Sequence identifiers, one per chain.
        seqs: Amino-acid sequences, parallel to ``tags``.
        alignment_dir: Root directory for per-tag alignment subdirectories.
        args: Parsed command-line arguments (tool/database paths, flags).
    """
    # One scratch FASTA per process; rewritten for every sequence.
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")

    for tag, seq in zip(tags, seqs):
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            local_alignment_dir = os.path.join(alignment_dir, tag)

            if args.use_precomputed_alignments is None:
                logger.info(f"Generating alignments for {tag}...")

                os.makedirs(local_alignment_dir, exist_ok=True)

                # Multimer presets search templates with hmmsearch against
                # pdb_seqres; monomer presets use hhsearch against pdb70.
                if "multimer" in args.config_preset:
                    template_searcher = hmmsearch.Hmmsearch(
                        binary_path=args.hmmsearch_binary_path,
                        hmmbuild_binary_path=args.hmmbuild_binary_path,
                        database_path=args.pdb_seqres_database_path,
                    )
                else:
                    template_searcher = hhsearch.HHSearch(
                        binary_path=args.hhsearch_binary_path,
                        databases=[args.pdb70_database_path],
                    )

                # In seqemb mode, use AlignmentRunner only to generate templates
                if args.use_single_seq_mode:
                    alignment_runner = data_pipeline.AlignmentRunner(
                        jackhmmer_binary_path=args.jackhmmer_binary_path,
                        uniref90_database_path=args.uniref90_database_path,
                        template_searcher=template_searcher,
                        no_cpus=args.cpus,
                    )
                    # MSA-free mode: sequence embeddings stand in for MSAs.
                    embedding_generator = EmbeddingGenerator()
                    embedding_generator.run(tmp_fasta_path, alignment_dir)
                else:
                    alignment_runner = data_pipeline.AlignmentRunner(
                        jackhmmer_binary_path=args.jackhmmer_binary_path,
                        hhblits_binary_path=args.hhblits_binary_path,
                        uniref90_database_path=args.uniref90_database_path,
                        mgnify_database_path=args.mgnify_database_path,
                        bfd_database_path=args.bfd_database_path,
                        uniref30_database_path=args.uniref30_database_path,
                        uniclust30_database_path=args.uniclust30_database_path,
                        uniprot_database_path=args.uniprot_database_path,
                        template_searcher=template_searcher,
                        use_small_bfd=args.bfd_database_path is None,
                        no_cpus=args.cpus
                    )

                alignment_runner.run(
                    tmp_fasta_path, local_alignment_dir
                )
            else:
                logger.info(
                    f"Using precomputed alignments for {tag} at {alignment_dir}..."
                )
        finally:
            # Remove the temporary FASTA file even if an alignment tool failed,
            # so a crash does not leave scratch files in the output directory.
            os.remove(tmp_fasta_path)


122
123
124
125
def round_up_seqlen(seqlen, interval=None):
    """Round ``seqlen`` up to the next multiple of ``interval``.

    Uses exact integer arithmetic instead of ``math.ceil`` on a float
    division, which avoids any floating-point precision concerns.

    Args:
        seqlen: Sequence length in residues.
        interval: Rounding granularity; defaults to ``TRACING_INTERVAL``.

    Returns:
        The smallest multiple of ``interval`` that is >= ``seqlen``.
    """
    if interval is None:
        interval = TRACING_INTERVAL
    return (seqlen + interval - 1) // interval * interval


126
127
128
129
130
131
132
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the model input feature dict for one target.

    Writes the target's sequences to a temporary FASTA file, dispatches to
    the appropriate data-pipeline entry point (monomer, multimer, or legacy
    multi-sequence), and cleans up the scratch file afterwards.

    Args:
        tags: Sequence identifiers for the target.
        seqs: Amino-acid sequences, parallel to ``tags``.
        alignment_dir: Root directory containing per-tag alignments.
        data_processor: A ``DataPipeline`` / ``DataPipelineMultimer`` instance.
        args: Parsed command-line arguments.

    Returns:
        The raw feature dictionary produced by the data pipeline.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    # A single-record FASTA is just the one-element case of the joined form.
    with open(tmp_fasta_path, "w") as fp:
        fp.write(
            '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
        )

    try:
        if len(seqs) == 1:
            # Monomer: alignments live in a per-tag subdirectory.
            local_alignment_dir = os.path.join(alignment_dir, tags[0])
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path,
                alignment_dir=local_alignment_dir,
                seqemb_mode=args.use_single_seq_mode,
            )
        elif "multimer" in args.config_preset:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        else:
            # Legacy multi-sequence (non-multimer) path.
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove the temporary FASTA file even if processing raised.
        os.remove(tmp_fasta_path)

    return feature_dict
167

168

169
170
def list_files_with_extensions(directory, extensions):
    """Return the names of files in ``directory`` ending in ``extensions``.

    Args:
        directory: Path of the directory to scan (non-recursive).
        extensions: A suffix or tuple of suffixes, e.g. ``(".fasta", ".fa")``.

    Returns:
        A list of bare filenames (not full paths) matching the suffixes.
    """
    # Renamed parameter: the original shadowed the builtin `dir`.
    return [f for f in os.listdir(directory) if f.endswith(extensions)]
171

172

173
def main(args):
    """Run OpenFold inference on every FASTA file in ``args.fasta_dir``.

    For each loaded model and each target: computes or reuses alignments,
    builds and processes the feature dict, runs the model, writes the
    unrelaxed structure (PDB or ModelCIF), optionally relaxes it, and
    optionally pickles the raw model outputs.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # "seqemb" presets imply single-sequence (MSA-free) mode.
    if args.config_preset.startswith("seq"):
        args.use_single_seq_mode = True

    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

    if args.trace_model:
        if not config.data.predict.fixed_size:
            raise ValueError(
                "Tracing requires that fixed_size mode be enabled in the config"
            )

    is_multimer = "multimer" in args.config_preset

    # FIX: test the flag's *value*. The previous check,
    # `"use_custom_template" in args`, was always True because argparse
    # defines the attribute (default False) unconditionally, so the
    # custom-template featurizer was always selected.
    if args.use_custom_template:
        template_featurizer = templates.CustomHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date="9999-12-31", # just dummy, not used
            max_hits=-1, # just dummy, not used
            kalign_binary_path=args.kalign_binary_path
            )
    elif is_multimer:
        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    else:
        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )
    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    # Gather input sequences
    tag_list = []
    seq_list = []
    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        fasta_path = os.path.join(args.fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if not is_multimer and len(tags) != 1:
            # Consistency: use the module logger instead of bare print().
            logger.warning(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
        tag = '-'.join(tags)

        tag_list.append((tag, tags))
        seq_list.append(seqs)

    # Shortest targets first, so tracing re-traces as rarely as possible.
    seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
    sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
    feature_dicts = {}
    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    feature_dict = pad_feature_dict_seq(
                        feature_dict, round_up_seqlen(n),
                    )

                feature_dicts[tag] = feature_dict

            if args.trace_model:
                # FIX: recompute from the (possibly cached, already padded)
                # dict so the value is never stale on later model iterations;
                # rounding an already-rounded length is a no-op.
                rounded_seqlen = round_up_seqlen(feature_dict["aatype"].shape[-2])

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = "_unrelaxed.pdb"
            if args.cif_output:
                unrelaxed_file_suffix = "_unrelaxed.cif"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name,
                              args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")
369
370
371
372


if __name__ == "__main__":
    # Command-line interface: positional input/template dirs, then optional
    # tool, database, and output configuration flags.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="Path to directory containing FASTA files, one sequence per file"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--use_custom_template", action="store_true", default=False,
        help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
    )
    parser.add_argument(
        "--use_single_seq_mode", action="store_true", default=False,
        help="""Use single sequence embeddings instead of MSAs."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    # FIX: help string previously began with a stray fourth quote
    # (`""""Whether...`), which rendered a literal '"' in --help output.
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Fall back to the bundled pretrained AlphaFold parameters when neither
    # a JAX parameter file nor an OpenFold checkpoint was given.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )
    main(args)