run_pretrained_openfold.py 17.2 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
3
#
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
4
5
6
7
8
9
10
11
12
13
14
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
20
21
22
import pickle
import random
import time
23

24
25
26
# Module-wide logger: install the default root handler, then make sure
# INFO-level progress messages from this script are emitted.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
import torch
29
30
31
# Extract the numeric major/minor components of the installed torch version.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])

# The matmul-precision switch is only applied on torch >= 1.12.
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

41
from openfold.config import model_config
42
from openfold.data import templates, feature_pipeline, data_pipeline
43
44
45
46
47
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
                                         prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
48
49
50
51
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
52

53
from scripts.precompute_embeddings import EmbeddingGenerator
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from scripts.utils import add_data_args
55

56

57
# Granularity (in residues) used by round_up_seqlen(): when --trace_model is
# set, sequence lengths are rounded up to a multiple of this value before the
# features are padded and the model is traced.
TRACING_INTERVAL = 50
58
59


60
def precompute_alignments(tags, seqs, alignment_dir, args):
    """Generate alignments for every sequence that does not have them yet.

    Alignments for a sequence live in ``<alignment_dir>/<tag>``. They are
    generated only when ``args.use_precomputed_alignments`` is None and that
    directory does not exist; otherwise the sequence is skipped.

    Args:
        tags: Sequence tags, one per entry of ``seqs``.
        seqs: Sequences, in the same order as ``tags``.
        alignment_dir: Directory holding one alignment sub-directory per tag.
        args: Parsed command-line arguments (output directory, tool binary
            paths, database paths, config preset, CPU count, ...).
    """
    for tag, seq in zip(tags, seqs):
        # Must match the path generate_feature_dict() reads from. Joining
        # alignment_dir only once keeps this correct for relative paths too.
        local_alignment_dir = os.path.join(alignment_dir, tag)

        if (
            args.use_precomputed_alignments is not None
            or os.path.isdir(local_alignment_dir)
        ):
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")

        os.makedirs(local_alignment_dir)

        # The alignment tools read their query from a file, so write the
        # sequence to a per-process temporary FASTA. It is only needed when
        # alignments are actually generated.
        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            # In seqemb mode, use AlignmentRunner only to generate templates
            if args.use_single_seq_mode:
                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
                    uniref90_database_path=args.uniref90_database_path,
                    no_cpus=args.cpus,
                )
                embedding_generator = EmbeddingGenerator()
                embedding_generator.run(tmp_fasta_path, alignment_dir)
            else:
                is_multimer = "multimer" in args.config_preset
                if is_multimer:
                    template_searcher = hmmsearch.Hmmsearch(
                        binary_path=args.hmmsearch_binary_path,
                        hmmbuild_binary_path=args.hmmbuild_binary_path,
                        database_path=args.pdb_seqres_database_path,
                    )
                else:
                    template_searcher = hhsearch.HHSearch(
                        binary_path=args.hhsearch_binary_path,
                        databases=[args.pdb70_database_path],
                    )

                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
                    hhblits_binary_path=args.hhblits_binary_path,
                    uniref90_database_path=args.uniref90_database_path,
                    mgnify_database_path=args.mgnify_database_path,
                    bfd_database_path=args.bfd_database_path,
                    uniref30_database_path=args.uniref30_database_path,
                    uniclust30_database_path=args.uniclust30_database_path,
                    uniprot_database_path=args.uniprot_database_path,
                    template_searcher=template_searcher,
                    use_small_bfd=args.bfd_database_path is None,
                    # Use the --cpus option defined by this script's parser,
                    # consistent with the seqemb branch above.
                    no_cpus=args.cpus,
                )

            alignment_runner.run(
                tmp_fasta_path, local_alignment_dir
            )
        finally:
            # Remove temporary FASTA file, even if alignment generation failed
            os.remove(tmp_fasta_path)


125
126
127
128
def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the nearest multiple of ``TRACING_INTERVAL``."""
    num_intervals = int(math.ceil(seqlen / TRACING_INTERVAL))
    return num_intervals * TRACING_INTERVAL


129
130
131
132
133
134
135
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dictionary for one prediction target.

    The sequences are written to a per-process temporary FASTA file inside
    ``args.output_dir``, handed to ``data_processor`` and the file is removed
    again afterwards (also when the data processor raises).

    Args:
        tags: Tags of the chains making up the target.
        seqs: Sequences of the chains, in the same order as ``tags``.
        alignment_dir: Directory holding one alignment sub-directory per tag.
        data_processor: Object providing ``process_fasta`` and, for
            multi-sequence non-multimer input, ``process_multiseq_fasta``.
        args: Parsed command-line arguments; ``output_dir``,
            ``config_preset`` and ``use_single_seq_mode`` are read.

    Returns:
        Whatever the selected ``data_processor`` method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    fasta_text = '\n'.join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs))

    try:
        with open(tmp_fasta_path, "w") as fp:
            fp.write(fasta_text)

        if len(seqs) == 1:
            # Single sequence: alignments live in the per-tag sub-directory.
            # NOTE(review): this branch is also taken for a multimer preset
            # with a single chain; confirm the multimer pipeline accepts
            # ``seqemb_mode``.
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path,
                alignment_dir=os.path.join(alignment_dir, tags[0]),
                seqemb_mode=args.use_single_seq_mode,
            )
        elif "multimer" in args.config_preset:
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove temporary FASTA file, even if feature generation failed
        if os.path.exists(tmp_fasta_path):
            os.remove(tmp_fasta_path)

    return feature_dict
170

171

172
173
def list_files_with_extensions(dir, extensions):
    """Return the names of entries in ``dir`` ending with one of ``extensions``.

    Args:
        dir: Directory to scan (not recursive).
        extensions: A single suffix string, or any iterable of suffix strings.

    Returns:
        A sorted list of matching entry names (not full paths). Sorting makes
        the processing order independent of the arbitrary order in which
        ``os.listdir`` returns entries.
    """
    # str.endswith only accepts a str or a tuple of str, so normalise lists,
    # sets, etc. to a tuple.
    if isinstance(extensions, str):
        suffixes = (extensions,)
    else:
        suffixes = tuple(extensions)
    return sorted(f for f in os.listdir(dir) if f.endswith(suffixes))
174

175

176
def main(args):
    """Run structure prediction for every FASTA file in ``args.fasta_dir``.

    For each model produced by ``load_models_from_command_line`` and each
    target (processed in order of increasing total sequence length), this
    computes missing alignments, builds and caches the feature dictionary,
    optionally traces the model, runs inference, writes the unrelaxed
    structure and optionally relaxes it and pickles the raw model outputs.

    Args:
        args: Parsed command-line arguments, as built by the argument parser
            in this script's ``__main__`` section.

    Raises:
        ValueError: If ``--trace_model`` is set but the selected config does
            not enable ``fixed_size`` mode.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    if args.config_preset.startswith("seq"):
        args.use_single_seq_mode = True

    config = model_config(
        args.config_preset,
        long_sequence_inference=args.long_sequence_inference,
    )

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset

    # Multimer and monomer presets use different template hit featurizers,
    # constructed with the same arguments.
    if is_multimer:
        featurizer_cls = templates.HmmsearchHitFeaturizer
    else:
        featurizer_cls = templates.HhsearchHitFeaturizer
    template_featurizer = featurizer_cls(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)
    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    tag_list = []
    seq_list = []
    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(args.fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if ((not is_multimer) and len(tags) != 1):
            print(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
        tag = '-'.join(tags)

        tag_list.append((tag, tags))
        seq_list.append(seqs)

    # Process targets from shortest to longest total length, so that with
    # --trace_model the traced length only ever needs to grow.
    seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
    sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)

    # Per-target caches shared across models. The padded length is cached
    # next to the features so that the tracing decision below uses the
    # current target's length even when its features come from the cache.
    feature_dicts = {}
    rounded_seqlens = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args)

            feature_dict = feature_dicts.get(tag, None)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    rounded_seqlen = round_up_seqlen(n)
                    feature_dict = pad_feature_dict_seq(
                        feature_dict, rounded_seqlen,
                    )
                    rounded_seqlens[tag] = rounded_seqlen

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                rounded_seqlen = rounded_seqlens[tag]
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = "_unrelaxed.pdb"
            if args.cif_output:
                unrelaxed_file_suffix = "_unrelaxed.cif"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name,
                              args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")
369
370
371
372


if __name__ == "__main__":
    # Command-line interface: two positional directories followed by optional
    # flags. Database and binary path options are added by add_data_args().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="Path to directory containing FASTA files, one sequence per file"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--use_single_seq_mode", action="store_true", default=False,
        help="""Use single sequence embeddings instead of MSAs."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # With no explicit weights, fall back to the bundled JAX parameters that
    # match the selected config preset.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            "The model is being run on CPU. Consider specifying "
            "--model_device for better performance"
        )

    main(args)