Commit 9db8dc36 authored by Jose Duarte's avatar Jose Duarte
Browse files

Adding output_cif as CLI argument

parent 4f662f83
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".") ...@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0]) torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1]) torch_minor_version = int(torch_versions[1])
if( if(
torch_major_version > 1 or torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12) (torch_major_version == 1 and torch_minor_version >= 12)
): ):
# Gives a large speedup on Ampere-class GPUs # Gives a large speedup on Ampere-class GPUs
...@@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag) local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)): if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
...@@ -141,13 +141,13 @@ def main(args): ...@@ -141,13 +141,13 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if(args.trace_model): if(args.trace_model):
if(not config.data.predict.fixed_size): if(not config.data.predict.fixed_size):
raise ValueError( raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config" "Tracing requires that fixed_size mode be enabled in the config"
) )
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -165,10 +165,10 @@ def main(args): ...@@ -165,10 +165,10 @@ def main(args):
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(2**32) random_seed = random.randrange(2**32)
np.random.seed(random_seed) np.random.seed(random_seed)
torch.manual_seed(random_seed + 1) torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
...@@ -183,7 +183,7 @@ def main(args): ...@@ -183,7 +183,7 @@ def main(args):
# Gather input sequences # Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp: with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read() data = fp.read()
tags, seqs = parse_fasta(data) tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
...@@ -206,10 +206,10 @@ def main(args): ...@@ -206,10 +206,10 @@ def main(args):
output_name = f'{tag}_{args.config_preset}' output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None: if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}'
# Does nothing if the alignments have already been computed # Does nothing if the alignments have already been computed
precompute_alignments(tags, seqs, alignment_dir, args) precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = feature_dicts.get(tag, None) feature_dict = feature_dicts.get(tag, None)
if(feature_dict is None): if(feature_dict is None):
feature_dict = generate_feature_dict( feature_dict = generate_feature_dict(
...@@ -234,7 +234,7 @@ def main(args): ...@@ -234,7 +234,7 @@ def main(args):
) )
processed_feature_dict = { processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device) k:torch.as_tensor(v, device=args.model_device)
for k,v in processed_feature_dict.items() for k,v in processed_feature_dict.items()
} }
...@@ -255,30 +255,36 @@ def main(args): ...@@ -255,30 +255,36 @@ def main(args):
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map( processed_feature_dict = tensor_tree_map(
lambda x: np.array(x[..., -1].cpu()), lambda x: np.array(x[..., -1].cpu()),
processed_feature_dict processed_feature_dict
) )
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output( unrelaxed_protein = prep_output(
out, out,
processed_feature_dict, processed_feature_dict,
feature_dict, feature_dict,
feature_processor, feature_processor,
args.config_preset, args.config_preset,
args.multimer_ri_gap, args.multimer_ri_gap,
args.subtract_plddt args.subtract_plddt
) )
unrelaxed_file_suffix = "_unrelaxed.pdb"
if args.cif_output:
unrelaxed_file_suffix = "_unrelaxed.cif"
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb' output_directory, f'{output_name}{unrelaxed_file_suffix}'
) )
with open(unrelaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein)) if args.cif_output:
fp.write(protein.to_modelcif(unrelaxed_protein))
else:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation: if not args.skip_relaxation:
# Relax the prediction. # Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...") logger.info(f"Running relaxation on {unrelaxed_output_path}...")
...@@ -373,12 +379,16 @@ if __name__ == "__main__": ...@@ -373,12 +379,16 @@ if __name__ == "__main__":
"--long_sequence_inference", action="store_true", default=False, "--long_sequence_inference", action="store_true", default=False,
help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details""" help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
) )
parser.add_argument(
"--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None): if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join( args.jax_param_path = os.path.join(
"openfold", "resources", "params", "openfold", "resources", "params",
"params_" + args.config_preset + ".npz" "params_" + args.config_preset + ".npz"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment