"googlemock/include/vscode:/vscode.git/clone" did not exist on "646603961bf88f8b07e5da952612cfe1e0e020fa"
run_pretrained_openfold.py 17.1 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
3
#
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
4
5
6
7
8
9
10
11
12
13
14
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
import logging
17
import math
18
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
20

21
22
23
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    update_timings, relax_protein

24
25
26
# Install a default root handler, then expose a file-scoped logger that
# emits INFO and above for this script.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
import pickle
29

30
import random
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
32
33
import time
import torch

34
35
36
37
# Parse "<major>.<minor>..." from the installed torch version string.
torch_versions = torch.__version__.split(".")
torch_major_version, torch_minor_version = (int(part) for part in torch_versions[:2])

# set_float32_matmul_precision is only called on torch >= 1.12.
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

46
from openfold.config import model_config
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
47
from openfold.data.tools import hhsearch, hmmsearch
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
48
from openfold.model.model import AlphaFold
49
from openfold.model.torchscript import script_preset_
50
from openfold.data import templates, feature_pipeline, data_pipeline
51
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
52
import openfold.np.relax.relax as relax
53

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
54
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
55
56
    tensor_tree_map,
)
57
58
59
60
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
61
from scripts.utils import add_data_args
62

63

64
# Granularity (in residues) used by round_up_seqlen() when --trace_model is
# set: inputs are padded up to a multiple of this value, and main() only
# re-traces a model when the padded length grows past the last traced one.
TRACING_INTERVAL: int = 50
65
66


Christina Floristean's avatar
Christina Floristean committed
67
def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
    """Generate alignments for each chain unless they already exist.

    For every ``(tag, seq)`` pair, alignment generation is skipped when
    ``args.use_precomputed_alignments`` is set or when the chain's alignment
    directory already exists. Otherwise a single-sequence temporary FASTA file
    is written to ``args.output_dir`` and passed to a new
    ``data_pipeline.AlignmentRunner``; the temporary file is always removed
    afterwards, even if the runner raises.

    Args:
        tags: FASTA header tags, one per chain.
        seqs: Sequences, parallel to ``tags``.
        alignment_dir: Root directory under which alignments are stored.
        args: Parsed command-line namespace; reads ``output_dir``,
            ``use_precomputed_alignments``, ``cpus`` and the alignment
            tool binary/database path arguments.
        is_multimer: If True, alignments go directly into ``alignment_dir``;
            otherwise into ``alignment_dir/<tag>``.
    """
    for tag, seq in zip(tags, seqs):
        if is_multimer:
            # NOTE(review): every chain shares ``alignment_dir`` here, so once
            # the first chain has created it, later chains are skipped by the
            # isdir() check below -- confirm this is intended.
            local_alignment_dir = alignment_dir
        else:
            # Must match the per-tag directory generate_feature_dict() reads.
            # (A nested os.path.join here used to yield
            # ``alignment_dir/alignment_dir/tag`` for relative paths.)
            local_alignment_dir = os.path.join(alignment_dir, tag)

        if (
            args.use_precomputed_alignments is not None
            or os.path.isdir(local_alignment_dir)
        ):
            logger.info(
                f"Using precomputed alignments for {tag} at {local_alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")

        os.makedirs(local_alignment_dir, exist_ok=True)

        # The alignment tools consume a FASTA file, so write this chain out.
        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        try:
            alignment_runner = data_pipeline.AlignmentRunner(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
                uniref30_database_path=args.uniref30_database_path,
                uniclust30_database_path=args.uniclust30_database_path,
                no_cpus=args.cpus,
            )
            alignment_runner.run(
                tmp_fasta_path, local_alignment_dir
            )
        finally:
            # Remove temporary FASTA file, even if alignment failed
            os.remove(tmp_fasta_path)


107
108
109
110
def round_up_seqlen(seqlen):
    """Return the smallest multiple of ``TRACING_INTERVAL`` that is >= ``seqlen``.

    The result is always an ``int``.
    """
    n_intervals = int(math.ceil(seqlen / TRACING_INTERVAL))
    return TRACING_INTERVAL * n_intervals


111
112
113
114
115
116
117
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dictionary for one prediction target.

    The target's sequences are written to a temporary FASTA file in
    ``args.output_dir`` (removed before returning, including on error) and
    handed to ``data_processor``:

    * one sequence: ``process_fasta`` with alignments from
      ``alignment_dir/<tag>``;
    * several sequences and a ``"multimer"`` config preset:
      ``process_fasta`` with ``alignment_dir`` itself;
    * several sequences otherwise: ``process_multiseq_fasta`` with
      ``alignment_dir`` as the super alignment directory.

    Args:
        tags: FASTA header tags, one per sequence.
        seqs: Sequences, parallel to ``tags``.
        alignment_dir: Root alignment directory.
        data_processor: Object providing ``process_fasta`` (and, for the
            multi-sequence monomer case, ``process_multiseq_fasta``).
        args: Parsed command-line namespace; reads ``output_dir`` and
            ``config_preset``.

    Returns:
        Whatever the selected ``data_processor`` method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    try:
        if len(seqs) == 1:
            tag = tags[0]
            seq = seqs[0]
            with open(tmp_fasta_path, "w") as fp:
                fp.write(f">{tag}\n{seq}")

            local_alignment_dir = os.path.join(alignment_dir, tag)
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        else:
            # Both multi-sequence paths consume the same combined FASTA.
            with open(tmp_fasta_path, "w") as fp:
                fp.write(
                    '\n'.join(f">{t}\n{s}" for t, s in zip(tags, seqs))
                )

            if "multimer" in args.config_preset:
                feature_dict = data_processor.process_fasta(
                    fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
                )
            else:
                feature_dict = data_processor.process_multiseq_fasta(
                    fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
                )
    finally:
        # Remove temporary FASTA file. Guarded so that a failure to create it
        # does not get masked by a FileNotFoundError here.
        if os.path.exists(tmp_fasta_path):
            os.remove(tmp_fasta_path)

    return feature_dict
150

151
152
def list_files_with_extensions(dir, extensions):
    """Return the names of entries in ``dir`` whose names end with ``extensions``.

    ``extensions`` may be a single suffix string or a tuple of suffixes (it is
    passed straight to ``str.endswith``). Bare entry names, not full paths,
    are returned in ``os.listdir`` order.
    """
    matches = []
    for entry_name in os.listdir(dir):
        if entry_name.endswith(extensions):
            matches.append(entry_name)
    return matches
153

154

155
def _collect_targets(fasta_dir, is_multimer):
    """Read every .fasta/.fa file in ``fasta_dir`` into ``((tag, tags), seqs)`` targets.

    Files without sequences are skipped, as are files with more than one
    sequence when ``is_multimer`` is False. ``tag`` is the file's FASTA tags
    joined with ``'-'``. Targets are returned sorted by ascending total
    sequence length.
    """
    tag_list = []
    seq_list = []
    for fasta_file in list_files_with_extensions(fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        fasta_path = os.path.join(fasta_dir, fasta_file)
        with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

        if not tags:
            logger.warning(f"{fasta_path} contains no sequences. Skipping...")
            continue

        if (not is_multimer) and len(tags) != 1:
            print(
                f"{fasta_path} contains more than one sequence but "
                f"multimer mode is not enabled. Skipping..."
            )
            continue

        # NOTE(review): tags are not checked for uniqueness; duplicate tags
        # across files would share alignment directories and output names.
        tag = '-'.join(tags)

        tag_list.append((tag, tags))
        seq_list.append(seqs)

    # Shortest targets first: with --trace_model, main() only re-traces a
    # model when the padded sequence length grows.
    return sorted(
        zip(tag_list, seq_list),
        key=lambda target: sum(len(s) for s in target[1]),
    )


def main(args):
    """Run OpenFold inference on every FASTA file in ``args.fasta_dir``.

    For each (model, output directory) pair yielded by
    ``load_models_from_command_line`` and for each target, this computes (or
    reuses) alignments, builds features, runs the model and writes the
    unrelaxed structure (PDB, or ModelCIF with ``--cif_output``). It then
    optionally runs relaxation and pickles the raw model outputs
    (``--save_outputs``).

    Raises:
        ValueError: if ``--trace_model`` is set but the config does not enable
            ``data.predict.fixed_size``.
    """
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

    if args.trace_model and not config.data.predict.fixed_size:
        raise ValueError(
            "Tracing requires that fixed_size mode be enabled in the config"
        )

    is_multimer = "multimer" in args.config_preset

    if is_multimer:
        if not args.use_precomputed_alignments:
            template_searcher = hmmsearch.Hmmsearch(
                binary_path=args.hmmsearch_binary_path,
                hmmbuild_binary_path=args.hmmbuild_binary_path,
                database_path=args.pdb_seqres_database_path,
            )
        else:
            template_searcher = None

        template_featurizer = templates.HmmsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )
    else:
        if not args.use_precomputed_alignments:
            template_searcher = hhsearch.HHSearch(
                binary_path=args.hhsearch_binary_path,
                databases=[args.pdb70_database_path],
            )
        else:
            template_searcher = None

        template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=args.kalign_binary_path,
            release_dates_path=args.release_dates_path,
            obsolete_pdbs_path=args.obsolete_pdbs_path
        )

    # NOTE(review): this runner (the only user of ``template_searcher``) is
    # never used below; precompute_alignments() builds its own AlignmentRunner
    # without ``template_searcher``/``uniprot_database_path``. It is kept
    # because constructing the searchers/runner may validate the configured
    # paths -- confirm whether it should be passed through instead.
    if not args.use_precomputed_alignments:
        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniref30_database_path=args.uniref30_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            uniprot_database_path=args.uniprot_database_path,
            template_searcher=template_searcher,
            use_small_bfd=(args.bfd_database_path is None),
            no_cpus=args.cpus,
        )
    else:
        alignment_runner = None

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    if is_multimer:
        data_processor = data_pipeline.DataPipelineMultimer(
            monomer_data_pipeline=data_processor,
        )

    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments

    sorted_targets = _collect_targets(args.fasta_dir, is_multimer)

    # Raw (optionally padded) features, cached per tag so they are only built
    # once even when several models are evaluated.
    feature_dicts = {}

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)
    for model, output_directory in model_generator:
        # Each model is traced separately, so start from "nothing traced".
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
            output_name = f'{tag}_{args.config_preset}'
            if args.output_postfix is not None:
                output_name = f'{output_name}_{args.output_postfix}'

            # Does nothing if the alignments have already been computed
            precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)

            feature_dict = feature_dicts.get(tag)
            if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
                    alignment_dir,
                    data_processor,
                    args,
                )

                if args.trace_model:
                    # Pad to a TRACING_INTERVAL multiple so traced models can
                    # be reused across targets of similar length.
                    feature_dict = pad_feature_dict_seq(
                        feature_dict,
                        round_up_seqlen(feature_dict["aatype"].shape[-2]),
                    )

                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
                feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

            if args.trace_model:
                # Derived from the features on every iteration: when features
                # come from the cache, a value computed only at cache-fill
                # time would be stale and could skip a required re-trace.
                rounded_seqlen = round_up_seqlen(feature_dict["aatype"].shape[-2])
                if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore
            processed_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                processed_feature_dict,
                feature_dict,
                feature_processor,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )

            unrelaxed_file_suffix = "_unrelaxed.pdb"
            if args.cif_output:
                unrelaxed_file_suffix = "_unrelaxed.cif"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
                relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")
377
378
379
380


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Path to directory containing FASTA files (.fasta or .fa). Each
             file holds one sequence, or several chains in multimer mode"""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Path to the directory of template structures in mmCIF format"
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    # NOTE(review): --preset is not read anywhere in this script.
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=int, default=None,
        help="""Seed for the numpy/torch RNGs; chosen at random if omitted"""
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
        help="""Skip relaxation of the predicted structures"""
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    parser.add_argument(
        "--trace_model", action="store_true", default=False,
        help="""Whether to convert parts of each model to TorchScript.
                Significantly improves runtime at the cost of lengthy
                'compilation.' Useful for large batch jobs."""
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
                 of the pLDDT itself"""
    )
    parser.add_argument(
        "--long_sequence_inference", action="store_true", default=False,
        help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
    )
    parser.add_argument(
        "--cif_output", action="store_true", default=False,
        help="Output predicted models in ModelCIF format instead of PDB format (default)"
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Fall back to the bundled JAX parameters matching the config preset.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying 
            --model_device for better performance"""
        )

    main(args)