run_pretrained_openfold.py 17.1 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
3
#
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
4
5
6
7
8
9
10
11
12
13
14
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
20
21
22
import pickle
import random
import time
23

24
25
26
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

import torch

# torch.__version__ looks like "MAJOR.MINOR.PATCH[+local]"; only the first two
# components are needed for the feature gate below.
torch_versions = torch.__version__.split(".")
torch_major_version, torch_minor_version = (
    int(torch_versions[0]),
    int(torch_versions[1]),
)

# Allowing reduced-precision ("high") float32 matmuls gives a large speedup on
# Ampere-class GPUs. The call is gated on torch >= 1.12, the release that
# introduced this setting.
if (torch_major_version, torch_minor_version) >= (1, 12):
    torch.set_float32_matmul_precision("high")

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

41
from openfold.config import model_config
42
from openfold.data import templates, feature_pipeline, data_pipeline
43
44
45
46
47
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
                                         prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
48
49
50
51
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
52

53
from scripts.precompute_embeddings import EmbeddingGenerator
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from scripts.utils import add_data_args
55

56

57
# Bucket size (in residues) used when tracing: inputs are padded up to a
# multiple of this, and the model is only re-traced when a target's padded
# length exceeds the length it was last traced at.
TRACING_INTERVAL: int = 50
58
59


60
def _make_template_searcher(args):
    """Return the template-search tool for the configured preset.

    Multimer presets use hmmsearch against ``args.pdb_seqres_database_path``;
    every other preset uses HHSearch against ``args.pdb70_database_path``.
    """
    if "multimer" in args.config_preset:
        return hmmsearch.Hmmsearch(
            binary_path=args.hmmsearch_binary_path,
            hmmbuild_binary_path=args.hmmbuild_binary_path,
            database_path=args.pdb_seqres_database_path,
        )
    return hhsearch.HHSearch(
        binary_path=args.hhsearch_binary_path,
        databases=[args.pdb70_database_path],
    )


def _make_alignment_runner(args, template_searcher):
    """Return an ``AlignmentRunner`` configured from the command-line args.

    In single-sequence (seqemb) mode the runner is used only to generate
    templates, so it is given just jackhmmer/UniRef90 and the template
    searcher; otherwise every configured MSA database is passed through.
    """
    if args.use_single_seq_mode:
        return data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            template_searcher=template_searcher,
            no_cpus=args.cpus,
        )
    return data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniref30_database_path=args.uniref30_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        uniprot_database_path=args.uniprot_database_path,
        template_searcher=template_searcher,
        # Request the small-BFD configuration when no full BFD path is given.
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus,
    )


def precompute_alignments(tags, seqs, alignment_dir, args):
    """Generate alignments for each chain unless they already exist.

    For every ``(tag, seq)`` pair, alignments are written to
    ``<alignment_dir>/<tag>``. Generation is skipped when
    ``args.use_precomputed_alignments`` is set, or when that per-chain
    directory already exists, so repeated calls for the same chain do not
    re-run the (expensive) alignment tools.

    Args:
        tags: Chain identifiers; each one names its alignment subdirectory.
        seqs: Amino-acid sequences, parallel to ``tags``.
        alignment_dir: Parent directory holding one subdirectory per tag.
        args: Parsed command-line arguments (preset, tool/database paths,
            ``cpus``, and ``output_dir`` for the temporary FASTA file).
    """
    for tag, seq in zip(tags, seqs):
        local_alignment_dir = os.path.join(alignment_dir, tag)

        if (
            args.use_precomputed_alignments is not None
            or os.path.isdir(local_alignment_dir)
        ):
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")
        os.makedirs(local_alignment_dir, exist_ok=True)

        # The alignment tools read their query from a FASTA file on disk.
        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            template_searcher = _make_template_searcher(args)
            alignment_runner = _make_alignment_runner(args, template_searcher)

            if args.use_single_seq_mode:
                # In seqemb mode, per-sequence embeddings stand in for MSAs.
                embedding_generator = EmbeddingGenerator()
                embedding_generator.run(tmp_fasta_path, alignment_dir)

            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        finally:
            # Remove temporary FASTA file, even if an alignment tool failed.
            os.remove(tmp_fasta_path)


122
123
124
125
def round_up_seqlen(seqlen, interval=None):
    """Round ``seqlen`` up to the next multiple of the tracing interval.

    Args:
        seqlen: Sequence length (number of residues).
        interval: Bucket size to round to. Defaults to ``TRACING_INTERVAL``,
            looked up at call time.

    Returns:
        The smallest multiple of ``interval`` that is >= ``seqlen``, as an int.

    Raises:
        ValueError: If ``interval`` is not positive.
    """
    if interval is None:
        interval = TRACING_INTERVAL
    if interval <= 0:
        raise ValueError(f"interval must be positive, got {interval}")
    # Negated floor division is an exact ceiling division for ints, avoiding
    # the float round-trip of math.ceil(seqlen / interval).
    return int(-(-seqlen // interval)) * interval


126
127
128
129
130
131
132
def _write_fasta(path, tags, seqs):
    """Write parallel ``tags``/``seqs`` to ``path`` as a newline-joined FASTA."""
    with open(path, "w") as fp:
        fp.write("\n".join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)))


def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dict for one target.

    The target's sequences are written to a temporary FASTA file in
    ``args.output_dir``, which is removed afterwards, even on error.

    Dispatch:
      * multimer preset: ``data_processor.process_fasta`` on the whole FASTA
        with the shared ``alignment_dir`` (this includes single-chain inputs,
        which must still go through the multimer pipeline);
      * one sequence: ``process_fasta`` with the per-chain directory
        ``<alignment_dir>/<tag>`` and ``seqemb_mode`` from the args;
      * otherwise: ``process_multiseq_fasta`` with ``alignment_dir`` as the
        super-alignment directory.

    Args:
        tags: Chain identifiers, parallel to ``seqs``.
        seqs: Amino-acid sequences.
        alignment_dir: Parent directory of the per-chain alignment dirs.
        data_processor: Monomer or multimer data pipeline.
        args: Parsed command-line arguments.

    Returns:
        Whatever feature dict the data pipeline produces.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    _write_fasta(tmp_fasta_path, tags, seqs)

    try:
        # The multimer check must come first: a single-chain multimer target
        # still needs the multimer pipeline and the shared alignment dir.
        if "multimer" in args.config_preset:
            return data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
            )
        if len(seqs) == 1:
            return data_processor.process_fasta(
                fasta_path=tmp_fasta_path,
                alignment_dir=os.path.join(alignment_dir, tags[0]),
                seqemb_mode=args.use_single_seq_mode,
            )
        return data_processor.process_multiseq_fasta(
            fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
        )
    finally:
        # Remove temporary FASTA file
        os.remove(tmp_fasta_path)
167

168

169
170
def list_files_with_extensions(dir, extensions):  # noqa: A002 (public kwarg name)
    """Return the names of entries in ``dir`` that end with ``extensions``.

    Args:
        dir: Directory to scan (not recursive). The name shadows the builtin
            but is kept for backward compatibility with keyword callers.
        extensions: A suffix string or tuple of suffix strings, as accepted
            by ``str.endswith``.

    Returns:
        Matching entry names (not full paths), sorted so that processing
        order is deterministic; ``os.listdir`` order is arbitrary.
    """
    return sorted(f for f in os.listdir(dir) if f.endswith(extensions))
171

172

173
def _build_data_processor(config, args, is_multimer):
    """Create the data pipeline that turns FASTA + alignments into features.

    Multimer presets use an hmmsearch-based template featurizer wrapped in a
    ``DataPipelineMultimer``; other presets use an HHSearch-based featurizer
    with the monomer ``DataPipeline``.
    """
    featurizer_cls = (
        templates.HmmsearchHitFeaturizer if is_multimer
        else templates.HhsearchHitFeaturizer
    )
    template_featurizer = featurizer_cls(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path,
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )
    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )
    return data_processor


def _collect_targets(args, is_multimer):
    """Parse every FASTA in ``args.fasta_dir`` into prediction targets.

    Returns:
        A list of ``((tag, tags), seqs)`` tuples sorted by total residue
        count, where ``tag`` joins the chain tags with ``-``. Empty files are
        skipped, as are multi-sequence files when not in multimer mode.
    """
    targets = []
    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(args.fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if not tags:
            logger.warning(f"{fasta_path} contains no sequences. Skipping...")
            continue
        if not is_multimer and len(tags) != 1:
            logger.warning(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        tag = "-".join(tags)
        targets.append(((tag, tags), seqs))

    # Shortest targets first: with tracing, the model is then only re-traced
    # as padded lengths grow.
    return sorted(targets, key=lambda target: sum(len(s) for s in target[1]))


def main(args):
    """Run OpenFold inference on every FASTA file in ``args.fasta_dir``.

    For each model yielded by ``load_models_from_command_line`` and each
    target (shortest first), alignments and features are computed once per
    target and reused across models. The model is optionally traced, then
    inference runs. The unrelaxed structure is written to the model's output
    directory, optionally followed by relaxation and a pickle of raw outputs.

    Raises:
        ValueError: If tracing is requested without ``fixed_size`` mode.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Presets whose name starts with "seq" always run in single-sequence mode.
    if args.config_preset.startswith("seq"):
        args.use_single_seq_mode = True

    config = model_config(
        args.config_preset,
        long_sequence_inference=args.long_sequence_inference,
    )

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset
    data_processor = _build_data_processor(config, args, is_multimer)

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    sorted_targets = _collect_targets(args, is_multimer)

    # Features, and their padded lengths when tracing, are built once per
    # target and reused by every model.
    feature_dicts = {}
    rounded_seqlens = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir,
    )

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            feature_dict = feature_dicts.get(tag)
            if feature_dict is None:
                # Alignments are only needed the first time this target's
                # features are built, not once per model.
                precompute_alignments(tags, seqs, alignment_dir, args)

                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    n = feature_dict["aatype"].shape[-2]
                    rounded_seqlens[tag] = round_up_seqlen(n)
                    feature_dict = pad_feature_dict_seq(
                        feature_dict, rounded_seqlens[tag],
                    )

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                # Look the length up per target. Cached targets skip the
                # branch above, so a value left over from another target must
                # never be used here.
                rounded_seqlen = rounded_seqlens[tag]
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(f"Tracing time: {tracing_time}")
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = (
                "_unrelaxed.cif" if args.cif_output else "_unrelaxed.pdb"
            )
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(
                    config, args.model_device, unrelaxed_protein,
                    output_directory, output_name, args.cif_output,
                )

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")
366
367
368
369


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Path to directory containing FASTA files (.fasta or .fa), one
                target per file. Files containing more than one sequence are
                only accepted in multimer mode."""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="""Path to directory containing template mmCIF files"""
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--use_single_seq_mode", action="store_true", default=False,
        help="""Use single sequence embeddings instead of MSAs."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=int, default=None,
        help="""Seed for NumPy (torch is seeded with this value + 1). A random
                seed is drawn if omitted."""
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
        help="""Skip the relaxation step; only the unrelaxed structure is
                written."""
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed,
                helps longer sequences fit into GPU memory, see the README for
                details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Fall back to the bundled JAX parameters for the chosen preset.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)