run_pretrained_openfold.py 17.5 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import argparse
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
from datetime import date
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
18
import gc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import logging
20
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
21
import os
Sam DeLuca's avatar
Sam DeLuca committed
22
from copy import deepcopy
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
23

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
24
import pickle
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
25
26
27
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict
)
28
29
import random
import sys
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
30
31
import time
import torch
32
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
33

34
from openfold.config import model_config
35
from openfold.data import templates, feature_pipeline, data_pipeline
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
36
from openfold.model.model import AlphaFold
37
from openfold.model.torchscript import script_preset_
38
from openfold.np import residue_constants, protein
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
39
40
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
41
42
    import_jax_weights_,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
43
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
44
45
46
    tensor_tree_map,
)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
47
from scripts.utils import add_data_args
48

49

50
51
52
53
54
# Attach a default stderr handler to the root logger, then report INFO and
# above from this script's own logger (named after the script file).
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
55
56
57
58
59
60
61
62
def precompute_alignments(tags, seqs, alignment_dir, args):
    """Run the alignment tools for each input sequence.

    Alignments are only computed when ``args.use_precomputed_alignments`` is
    None; otherwise this function does nothing and the caller reads
    alignments from the user-supplied directory instead.

    Args:
        tags: FASTA tags, one per sequence. Each tag names the per-sequence
            subdirectory of ``alignment_dir`` the results are written to.
        seqs: Sequences, parallel to ``tags``.
        alignment_dir: Directory under which ``<tag>/`` subdirectories are
            created.
        args: Parsed command-line arguments (tool and database paths,
            ``cpus``, ``output_dir``, ``use_precomputed_alignments``).
    """
    if args.use_precomputed_alignments is not None:
        # Alignments are supplied by the user; nothing to compute
        return

    pairs = list(zip(tags, seqs))
    if not pairs:
        return

    # The runner depends only on the command-line arguments, so build it
    # once rather than once per sequence.
    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        pdb70_database_path=args.pdb70_database_path,
        no_cpus=args.cpus,
    )

    # The PID in the name keeps concurrent runs sharing an output_dir apart
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    for tag, seq in pairs:
        logger.info(f"Generating alignments for {tag}...")

        local_alignment_dir = os.path.join(alignment_dir, tag)
        os.makedirs(local_alignment_dir, exist_ok=True)

        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")
        try:
            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        finally:
            # Remove the temporary FASTA file even if an alignment tool fails
            os.remove(tmp_fasta_path)


def run_model(model, batch, tag, args):
    """Run inference on one featurized input.

    Args:
        model: The model to run; it must expose ``config.template.enabled``
            and already live on ``args.model_device``.
        batch: Dict of feature arrays. Each value is converted with
            ``torch.as_tensor`` onto ``args.model_device``.
        tag: Name of the input, used only in log messages.
        args: Parsed command-line arguments; ``model_device`` is read.

    Returns:
        Whatever ``model(batch)`` returns (the raw model output).
    """
    with torch.no_grad():
        batch = {
            k: torch.as_tensor(v, device=args.model_device)
            for k, v in batch.items()
        }

        # Disable templates if there aren't any in the batch. The flag lives
        # on the model's config object (presumably shared with other models
        # built from the same config -- confirm), so it is restored afterwards.
        # Otherwise one template-less input would switch templates off for
        # every later input.
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any(
            "template_" in k for k in batch
        )

        logger.info(f"Running inference for {tag}...")
        try:
            t = time.perf_counter()
            out = model(batch)
            logger.info(f"Inference time: {time.perf_counter() - t}")
        finally:
            model.config.template.enabled = template_enabled

    return out


def prep_output(out, batch, feature_dict, feature_processor, args):
    """Convert NumPy model outputs into an unrelaxed ``protein.Protein``.

    Args:
        out: Model outputs as NumPy arrays. ``out["plddt"]`` is read, and the
            whole dict is passed to ``protein.from_prediction``.
        batch: Processed features with the recycling dimension removed.
            NOTE: ``batch["residue_index"]`` is modified in place so that
            residue numbering restarts for each chain.
        feature_dict: Raw (unprocessed) feature dict; supplies template names
            and the residue-index layout used to recover chain membership.
        feature_processor: The feature pipeline; its config supplies
            ``use_templates``, ``max_templates`` and ``max_recycling_iters``.
        args: Parsed command-line arguments (``config_preset``,
            ``multimer_ri_gap``).

    Returns:
        The unrelaxed protein, with pLDDT stored as B-factors.
    """
    plddt = out["plddt"]

    # Repeat each residue's pLDDT across every atom slot so it can be stored
    # as per-atom B-factors
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    data_config = feature_processor.config
    max_templates = data_config.predict.max_templates

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if data_config.common.use_templates and "template_domain_names" in feature_dict:
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[:max_templates]

        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"][
                :max_templates
            ]

    no_recycling = data_config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={max_templates}",
        f"config_preset={args.config_preset}",
    ])

    # For multi-chain FASTAs. Assumes (TODO confirm against the data
    # pipeline) that the k-th chain's residue_index is offset by
    # k * multimer_ri_gap, so dividing the offset by the gap recovers k.
    ri = feature_dict["residue_index"]
    chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
    chain_index = chain_index.astype(np.int64)

    # At the first residue of each new chain, record the index value there
    # (i + chain * gap) and subtract it from that chain's residues, so the
    # numbering restarts per chain. This mutates batch in place.
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * args.multimer_ri_gap

        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein


166
def parse_fasta(data):
    """Parse FASTA-formatted text into parallel lists of tags and sequences.

    A tag is the first whitespace-delimited token after ``>`` on a header
    line. The sequence lines that follow a header are stripped and
    concatenated, so multi-line records and CRLF line endings are handled.
    Text before the first header is ignored. A header with no tag (such as a
    bare ``>``) is skipped, so the lines after it continue the previous
    record. A header with no sequence lines yields an empty sequence.

    Args:
        data: FASTA file contents.

    Returns:
        ``(tags, seqs)``: two lists of equal length.
    """
    tags = []
    seq_chunks = []
    for raw_line in data.splitlines():
        line = raw_line.strip()
        if line.startswith(">"):
            header_tokens = line[1:].split()
            if not header_tokens:
                # Empty header: ignore it and keep extending the current record
                continue
            tags.append(header_tokens[0])
            seq_chunks.append([])
        elif line and seq_chunks:
            seq_chunks[-1].append(line)

    seqs = ["".join(chunks) for chunks in seq_chunks]
    return tags, seqs
177

178

179
180
181
182
183
184
185
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the raw feature dict for one FASTA input.

    The sequences are written to a temporary FASTA file in
    ``args.output_dir``, which is always removed afterwards, even if feature
    generation raises.

    Args:
        tags: FASTA tags, parallel to ``seqs``.
        seqs: Input sequences.
        alignment_dir: Directory holding one ``<tag>/`` alignment
            subdirectory per sequence.
        data_processor: Object providing ``process_fasta`` (used for a single
            sequence) and ``process_multiseq_fasta`` (used otherwise).
        args: Parsed command-line arguments; ``output_dir`` is read.

    Returns:
        The feature dict returned by ``data_processor``.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    with open(tmp_fasta_path, "w") as fp:
        fp.write('\n'.join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)))

    try:
        if len(seqs) == 1:
            local_alignment_dir = os.path.join(alignment_dir, tags[0])
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove temporary FASTA file
        os.remove(tmp_fasta_path)

    return feature_dict
210

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def get_model_basename(model_path):
    """Return the final path component of ``model_path`` without its extension.

    The path is normalized first, so a trailing separator (as on a checkpoint
    directory) does not produce an empty name.
    """
    final_component = os.path.basename(os.path.normpath(model_path))
    stem, _extension = os.path.splitext(final_component)
    return stem

def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if needed) and return the directory predictions are written to.

    This is ``<output_dir>/predictions``. When several models are evaluated,
    each model gets its own ``<model_name>`` subdirectory underneath it.
    """
    components = [output_dir, "predictions"]
    if multiple_model_mode:
        components.append(model_name)
    prediction_dir = os.path.join(*components)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir

def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Return the total number of models named by the two path arguments.

    Each argument is either falsy (not supplied; counts as zero) or a
    comma-separated list of paths, each of which counts as one model.
    """
    return sum(
        len(path_list.split(","))
        for path_list in (openfold_checkpoint_path, jax_param_path)
        if path_list
    )
233
234
235

def load_models_from_command_line(args, config):
    """Yield each requested model together with its prediction directory.

    Models named in ``args.jax_param_path`` come first, then those in
    ``args.openfold_checkpoint_path``; both are comma-separated path lists.
    Each model is constructed and loaded lazily, as the generator advances.

    Args:
        args: Parsed command-line arguments (``jax_param_path``,
            ``openfold_checkpoint_path``, ``model_device``, ``output_dir``).
        config: Model config passed to ``AlphaFold``.

    Yields:
        ``(model, output_directory)`` pairs; the model is in eval mode and on
        ``args.model_device``.

    Raises:
        ValueError: If neither path argument is given. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if not args.jax_param_path and not args.openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )

    multiple_model_mode = count_models_to_evaluate(
        args.openfold_checkpoint_path, args.jax_param_path
    ) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    if args.jax_param_path:
        for path in args.jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # Drop the leading "<prefix>_" token, e.g. "params_model_1" -> "model_1"
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(
                model, path, version=model_version
            )
            model = model.to(args.model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(
                args.output_dir, model_basename, multiple_model_mode
            )
            yield model, output_directory

    if args.openfold_checkpoint_path:
        for path in args.openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            model.load_state_dict(
                _load_openfold_state_dict(
                    path, checkpoint_basename, args.output_dir
                )
            )
            model = model.to(args.model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(
                args.output_dir, checkpoint_basename, multiple_model_mode
            )
            yield model, output_directory


def _load_openfold_state_dict(path, checkpoint_basename, output_dir):
    """Return the model parameters stored in an OpenFold checkpoint.

    ``path`` is either a DeepSpeed checkpoint directory or a ``.pt`` file.
    Tensors are loaded onto the CPU, so a checkpoint saved on a GPU can be
    read on a CPU-only host; the caller moves the model afterwards.
    """
    if os.path.isdir(path):
        # A DeepSpeed checkpoint. Convert it once to an fp32 state dict,
        # cached in output_dir so later runs can reuse it.
        ckpt_path = os.path.join(output_dir, checkpoint_basename + ".pt")
        if not os.path.isfile(ckpt_path):
            convert_zero_checkpoint_to_fp32_state_dict(
                path,
                ckpt_path,
            )
        d = torch.load(ckpt_path, map_location="cpu")
        return d["ema"]["params"]

    d = torch.load(path, map_location="cpu")
    if "ema" in d:
        # The public weights have had this done to them already
        d = d["ema"]["params"]
    return d

298
299
def list_files_with_extensions(dir, extensions):
    """Return the names of regular files directly inside ``dir`` with a matching suffix.

    Args:
        dir: Directory to scan (not recursively).
        extensions: A suffix string or a tuple of suffixes, as accepted by
            ``str.endswith``.

    Returns:
        Sorted list of matching file names (not full paths). Sorting makes
        the processing order deterministic, since ``os.listdir`` order is
        arbitrary. Subdirectories are skipped even if their names match,
        because callers open each entry as a file.
    """
    return sorted(
        name for name in os.listdir(dir)
        if name.endswith(extensions)
        and os.path.isfile(os.path.join(dir, name))
    )
300
301
302
303
304

def main(args):
    """Predict structures for every FASTA file in ``args.fasta_dir``.

    For each FASTA file the function does the following:

    - Computes alignments, or uses the precomputed ones.
    - Builds features.
    - For each requested model, runs inference and writes an unrelaxed PDB.
    - Optionally writes an Amber-relaxed PDB and the pickled model outputs.

    A (file, model) pair whose unrelaxed PDB already exists is skipped before
    inference is run.
    """
    # Create the output directory
    output_dir_base = args.output_dir
    os.makedirs(output_dir_base, exist_ok=True)

    config = model_config(args.config_preset)

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        max_hits=config.data.predict.max_templates,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=args.release_dates_path,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )

    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )

    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
    # NOTE(review): random_seed is computed but never used in this function.
    # TODO: wire it into seeding/feature processing or drop the option.

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    if args.use_precomputed_alignments is None:
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments
        logger.info(f"Using precomputed alignments at {alignment_dir}...")

    # The relaxer depends only on the config and device, so build it once
    amber_relaxer = None
    if not args.skip_relaxation:
        amber_relaxer = relax.AmberRelaxation(
            use_gpu=(args.model_device != "cpu"),
            **config.relax,
        )

    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
        with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)
        tag = '-'.join(tags)

        output_name = f'{tag}_{args.config_preset}'
        if args.output_postfix is not None:
            output_name = f'{output_name}_{args.output_postfix}'

        precompute_alignments(tags, seqs, alignment_dir, args)

        feature_dict = generate_feature_dict(
            tags,
            seqs,
            alignment_dir,
            data_processor,
            args,
        )

        processed_feature_dict = feature_processor.process_features(
            feature_dict, mode='predict',
        )

        for model, output_directory in load_models_from_command_line(args, config):
            unrelaxed_output_path = os.path.join(
                output_directory, f'{output_name}_unrelaxed.pdb'
            )

            # Output already exists: skip before paying for inference
            if os.path.exists(unrelaxed_output_path):
                logger.info(
                    f"Output already exists at {unrelaxed_output_path}; skipping..."
                )
                continue

            working_batch = deepcopy(processed_feature_dict)
            out = run_model(model, working_batch, tag, args)

            # Toss out the recycling dimensions --- we don't need them anymore
            working_batch = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()), working_batch
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out, working_batch, feature_dict, feature_processor, args
            )

            with open(unrelaxed_output_path, 'w') as fp:
                fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            if amber_relaxer is not None:
                relaxed_output_path = os.path.join(
                    output_directory, f'{output_name}_relaxed.pdb'
                )
                _relax_and_save(
                    amber_relaxer,
                    unrelaxed_protein,
                    unrelaxed_output_path,
                    relaxed_output_path,
                    args.model_device,
                )

            if args.save_outputs:
                output_dict_path = os.path.join(
                    output_directory, f'{output_name}_output_dict.pkl'
                )
                with open(output_dict_path, "wb") as fp:
                    pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)

                logger.info(f"Model output written to {output_dict_path}...")


def _relax_and_save(
    amber_relaxer,
    unrelaxed_protein,
    unrelaxed_output_path,
    relaxed_output_path,
    model_device,
):
    """Amber-relax a prediction and write the relaxed PDB to ``relaxed_output_path``.

    When ``model_device`` names a specific CUDA device (e.g. ``"cuda:1"``),
    ``CUDA_VISIBLE_DEVICES`` is narrowed to that device for the relaxation.
    The variable is then restored exactly: if it was unset, it is unset
    again, because an empty value would hide every GPU. It is restored even
    if relaxation raises.
    """
    # Relax the prediction.
    logger.info(f"Running relaxation on {unrelaxed_output_path}...")
    t = time.perf_counter()
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if model_device.startswith("cuda") and ":" in model_device:
        os.environ["CUDA_VISIBLE_DEVICES"] = model_device.split(":")[-1]
    try:
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    finally:
        if visible_devices is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    logger.info(f"Relaxation time: {time.perf_counter() - t}")

    # Save the relaxed PDB.
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)

    logger.info(f"Relaxed output written to {relaxed_output_path}...")
418

419

420
421
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Path to directory containing FASTA files (.fasta or .fa).
             A file containing several sequences is treated as a single
             multi-chain input"""
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="""Path to directory containing template mmCIF files"""
    )
    parser.add_argument(
        "--use_precomputed_alignments", type=str, default=None,
        help="""Path to alignment directory. If provided, alignment computation 
                is skipped and database path arguments are ignored."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters (comma-separate several paths to
             evaluate multiple models). If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to 
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint (comma-separate several paths to
             evaluate multiple models). Each can be either a DeepSpeed 
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="""Number of CPUs with which to run alignment tools"""
    )
    parser.add_argument(
        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
    parser.add_argument(
        "--output_postfix", type=str, default=None,
        help="""Postfix for output prediction filenames"""
    )
    parser.add_argument(
        "--data_random_seed", type=str, default=None
    )
    parser.add_argument(
        "--skip_relaxation", action="store_true", default=False,
    )
    parser.add_argument(
        "--multimer_ri_gap", type=int, default=200,
        help="""Residue index offset between multiple sequences, if provided"""
    )
    add_data_args(parser)
    args = parser.parse_args()

    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        # Relative to the current working directory
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            "The model is being run on CPU. Consider specifying "
            "--model_device for better performance"
        )

    main(args)