Unverified Commit 296cd7c6 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #287 from josemduarte/modelcif_output

New option to output in ModelCIF format instead of PDB format
parents 685e8b5f 03f3a7f5
...@@ -27,4 +27,5 @@ dependencies: ...@@ -27,4 +27,5 @@ dependencies:
- typing-extensions==3.10.0.2 - typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10 - pytorch_lightning==1.5.10
- wandb==0.12.21 - wandb==0.12.21
- modelcif==0.7
- git+https://github.com/NVIDIA/dllogger.git - git+https://github.com/NVIDIA/dllogger.git
...@@ -23,6 +23,13 @@ import string ...@@ -23,6 +23,13 @@ import string
from openfold.np import residue_constants from openfold.np import residue_constants
from Bio.PDB import PDBParser from Bio.PDB import PDBParser
import numpy as np import numpy as np
import modelcif
import modelcif.model
import modelcif.dumper
import modelcif.reference
import modelcif.protocol
import modelcif.alignment
import modelcif.qa_metric
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
...@@ -56,7 +63,7 @@ class Protein: ...@@ -56,7 +63,7 @@ class Protein:
# Chain indices for multi-chain predictions # Chain indices for multi-chain predictions
chain_index: Optional[np.ndarray] = None chain_index: Optional[np.ndarray] = None
# Optional remark about the protein. Included as a comment in output PDB # Optional remark about the protein. Included as a comment in output PDB
# files # files
remark: Optional[str] = None remark: Optional[str] = None
...@@ -75,8 +82,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: ...@@ -75,8 +82,7 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
Args: Args:
pdb_str: The contents of the pdb file pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed. is parsed.
Returns: Returns:
...@@ -171,7 +177,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: ...@@ -171,7 +177,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0 tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
] ]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]]) groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C'] atoms = ['N', 'CA', 'C']
aatype = None aatype = None
atom_positions = None atom_positions = None
...@@ -246,7 +252,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str: ...@@ -246,7 +252,7 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
""" """
out_pdb_lines = [] out_pdb_lines = []
lines = pdb_str.split('\n') lines = pdb_str.split('\n')
remark = prot.remark remark = prot.remark
if(remark is not None): if(remark is not None):
out_pdb_lines.append(f"REMARK {remark}") out_pdb_lines.append(f"REMARK {remark}")
...@@ -341,7 +347,7 @@ def to_pdb(prot: Protein) -> str: ...@@ -341,7 +347,7 @@ def to_pdb(prot: Protein) -> str:
0 0
] # Protein supports only C, N, O, S, this works. ] # Protein supports only C, N, O, S, this works.
charge = "" charge = ""
chain_tag = "A" chain_tag = "A"
if(chain_index is not None): if(chain_index is not None):
chain_tag = chain_tags[chain_index[i]] chain_tag = chain_tags[chain_index[i]]
...@@ -385,6 +391,134 @@ def to_pdb(prot: Protein) -> str: ...@@ -385,6 +391,134 @@ def to_pdb(prot: Protein) -> str:
return "\n".join(pdb_lines) return "\n".join(pdb_lines)
def to_modelcif(prot: Protein) -> str:
"""
Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates
will be treated as the same polymer entity. But note that if chains differ in modelled regions,
no attempt is made at identifying them as a single polymer entity.
Args:
prot: The protein to convert to PDB.
Returns:
ModelCIF string.
"""
restypes = residue_constants.restypes + ["X"]
atom_types = residue_constants.atom_types
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
n = aatype.shape[0]
if chain_index is None:
chain_index = [0 for i in range(n)]
system = modelcif.System(title='OpenFold prediction')
# Finding chains and creating entities
seqs = {}
seq = []
last_chain_idx = None
for i in range(n):
if last_chain_idx is not None and last_chain_idx != chain_index[i]:
seqs[last_chain_idx] = seq
seq = []
seq.append(restypes[aatype[i]])
last_chain_idx = chain_index[i]
# finally add the last chain
seqs[last_chain_idx] = seq
# now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions)
unique_seqs = {}
for chain_idx, seq_list in seqs.items():
seq = "".join(seq_list)
if seq in unique_seqs:
unique_seqs[seq].append(chain_idx)
else:
unique_seqs[seq] = [chain_idx]
# adding 1 entity per unique sequence
entities_map = {}
for key, value in unique_seqs.items():
model_e = modelcif.Entity(key, description='Model subunit')
for chain_idx in value:
entities_map[chain_idx] = model_e
chain_tags = string.ascii_uppercase
asym_unit_map = {}
for chain_idx in set(chain_index):
# Define the model assembly
chain_id = chain_tags[chain_idx]
asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id)
asym_unit_map[chain_idx] = asym
modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly')
class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
name = "pLDDT"
software = None
description = "Predicted lddt"
class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
name = "pLDDT"
software = None
description = "Global pLDDT, mean of per-residue pLDDTs"
class _MyModel(modelcif.model.AbInitioModel):
def get_atoms(self):
# Add all atom sites.
for i in range(n):
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
element = atom_name[0] # Protein supports only C, N, O, S, this works.
yield modelcif.model.Atom(
asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
seq_id=residue_index[i], atom_id=atom_name,
x=pos[0], y=pos[1], z=pos[2],
het=False, biso=b_factor, occupancy=1.00)
def add_scores(self):
# local scores
plddt_per_residue = {}
for i in range(n):
for mask, b_factor in zip(atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
# add 1 per residue, not 1 per atom
if chain_index[i] not in plddt_per_residue:
# first time a chain index is seen: add the key and start the residue dict
plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
if residue_index[i] not in plddt_per_residue[chain_index[i]]:
plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
plddts = []
for chain_idx in plddt_per_residue:
for residue_idx in plddt_per_residue[chain_idx]:
plddt = plddt_per_residue[chain_idx][residue_idx]
plddts.append(plddt)
self.qa_metrics.append(
_LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt))
# global score
self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))
# Add the model and modeling protocol to the file and write them out:
model = _MyModel(assembly=modeled_assembly, name='Best scoring model')
model.add_scores()
model_group = modelcif.model.ModelGroup([model], name='All models')
system.model_groups.append(model_group)
fh = io.StringIO()
modelcif.dumper.write(fh, [system])
return fh.getvalue()
def ideal_atom_mask(prot: Protein) -> np.ndarray: def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask. """Computes an ideal atom mask.
......
...@@ -524,9 +524,6 @@ def run_pipeline( ...@@ -524,9 +524,6 @@ def run_pipeline(
_check_residues_are_well_defined(prot) _check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks) pdb_string = clean_protein(prot, checks=checks)
# We keep the input around to restore metadata deleted by the relaxer
input_prot = prot
exclude_residues = exclude_residues or [] exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues) exclude_residues = set(exclude_residues)
violations = np.inf violations = np.inf
......
...@@ -57,7 +57,7 @@ class AmberRelaxation(object): ...@@ -57,7 +57,7 @@ class AmberRelaxation(object):
self._use_gpu = use_gpu self._use_gpu = use_gpu
def process( def process(
self, *, prot: protein.Protein self, *, prot: protein.Protein, cif_output: bool
) -> Tuple[str, Dict[str, Any], np.ndarray]: ) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline( out = amber_minimize.run_pipeline(
...@@ -89,5 +89,11 @@ class AmberRelaxation(object): ...@@ -89,5 +89,11 @@ class AmberRelaxation(object):
] ]
min_pdb = protein.add_pdb_headers(prot, min_pdb) min_pdb = protein.add_pdb_headers(prot, min_pdb)
output_str = min_pdb
if cif_output:
# TODO the model cif will be missing some metadata like headers (PARENTs and
# REMARK with some details of the run, like num of recycles)
final_prot = protein.from_pdb_string(min_pdb)
output_str = protein.to_modelcif(final_prot)
return min_pdb, debug_data, violations return output_str, debug_data, violations
...@@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult ...@@ -228,7 +228,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
return unrelaxed_protein return unrelaxed_protein
def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name): def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name, cif_output):
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
use_gpu=(model_device != "cpu"), use_gpu=(model_device != "cpu"),
**config.relax, **config.relax,
...@@ -239,7 +239,8 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out ...@@ -239,7 +239,8 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out
if "cuda" in model_device: if "cuda" in model_device:
device_no = model_device.split(":")[-1] device_no = model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) # the struct_str will contain either a PDB-format or a ModelCIF format string
struct_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein, cif_output=cif_output)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
relaxation_time = time.perf_counter() - t relaxation_time = time.perf_counter() - t
...@@ -247,10 +248,13 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out ...@@ -247,10 +248,13 @@ def relax_protein(config, model_device, unrelaxed_protein, output_directory, out
update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json")) update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))
# Save the relaxed PDB. # Save the relaxed PDB.
suffix = "_relaxed.pdb"
if cif_output:
suffix = "_relaxed.cif"
relaxed_output_path = os.path.join( relaxed_output_path = os.path.join(
output_directory, f'{output_name}_relaxed.pdb' output_directory, f'{output_name}{suffix}'
) )
with open(relaxed_output_path, 'w') as fp: with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str) fp.write(struct_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...") logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".") ...@@ -35,7 +35,7 @@ torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0]) torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1]) torch_minor_version = int(torch_versions[1])
if( if(
torch_major_version > 1 or torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12) (torch_major_version == 1 and torch_minor_version >= 12)
): ):
# Gives a large speedup on Ampere-class GPUs # Gives a large speedup on Ampere-class GPUs
...@@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -70,7 +70,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag) local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)): if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
...@@ -141,13 +141,13 @@ def main(args): ...@@ -141,13 +141,13 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if(args.trace_model): if(args.trace_model):
if(not config.data.predict.fixed_size): if(not config.data.predict.fixed_size):
raise ValueError( raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config" "Tracing requires that fixed_size mode be enabled in the config"
) )
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
...@@ -165,10 +165,10 @@ def main(args): ...@@ -165,10 +165,10 @@ def main(args):
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(2**32) random_seed = random.randrange(2**32)
np.random.seed(random_seed) np.random.seed(random_seed)
torch.manual_seed(random_seed + 1) torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
...@@ -183,7 +183,7 @@ def main(args): ...@@ -183,7 +183,7 @@ def main(args):
# Gather input sequences # Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp: with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read() data = fp.read()
tags, seqs = parse_fasta(data) tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique" # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags) tag = '-'.join(tags)
...@@ -206,10 +206,10 @@ def main(args): ...@@ -206,10 +206,10 @@ def main(args):
output_name = f'{tag}_{args.config_preset}' output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None: if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}' output_name = f'{output_name}_{args.output_postfix}'
# Does nothing if the alignments have already been computed # Does nothing if the alignments have already been computed
precompute_alignments(tags, seqs, alignment_dir, args) precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = feature_dicts.get(tag, None) feature_dict = feature_dicts.get(tag, None)
if(feature_dict is None): if(feature_dict is None):
feature_dict = generate_feature_dict( feature_dict = generate_feature_dict(
...@@ -234,7 +234,7 @@ def main(args): ...@@ -234,7 +234,7 @@ def main(args):
) )
processed_feature_dict = { processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device) k:torch.as_tensor(v, device=args.model_device)
for k,v in processed_feature_dict.items() for k,v in processed_feature_dict.items()
} }
...@@ -255,34 +255,40 @@ def main(args): ...@@ -255,34 +255,40 @@ def main(args):
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map( processed_feature_dict = tensor_tree_map(
lambda x: np.array(x[..., -1].cpu()), lambda x: np.array(x[..., -1].cpu()),
processed_feature_dict processed_feature_dict
) )
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output( unrelaxed_protein = prep_output(
out, out,
processed_feature_dict, processed_feature_dict,
feature_dict, feature_dict,
feature_processor, feature_processor,
args.config_preset, args.config_preset,
args.multimer_ri_gap, args.multimer_ri_gap,
args.subtract_plddt args.subtract_plddt
) )
unrelaxed_file_suffix = "_unrelaxed.pdb"
if args.cif_output:
unrelaxed_file_suffix = "_unrelaxed.cif"
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
output_directory, f'{output_name}_unrelaxed.pdb' output_directory, f'{output_name}{unrelaxed_file_suffix}'
) )
with open(unrelaxed_output_path, 'w') as fp: with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein)) if args.cif_output:
fp.write(protein.to_modelcif(unrelaxed_protein))
else:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation: if not args.skip_relaxation:
# Relax the prediction. # Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...") logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name) relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)
if args.save_outputs: if args.save_outputs:
output_dict_path = os.path.join( output_dict_path = os.path.join(
...@@ -373,12 +379,16 @@ if __name__ == "__main__": ...@@ -373,12 +379,16 @@ if __name__ == "__main__":
"--long_sequence_inference", action="store_true", default=False, "--long_sequence_inference", action="store_true", default=False,
help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details""" help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
) )
parser.add_argument(
"--cif_output", action="store_true", default=False,
help="Output predicted models in ModelCIF format instead of PDB format (default)"
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None): if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join( args.jax_param_path = os.path.join(
"openfold", "resources", "params", "openfold", "resources", "params",
"params_" + args.config_preset + ".npz" "params_" + args.config_preset + ".npz"
) )
......
...@@ -106,7 +106,7 @@ def main(args): ...@@ -106,7 +106,7 @@ def main(args):
logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Output written to {unrelaxed_output_path}...")
logger.info(f"Running relaxation on {unrelaxed_output_path}...") logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name) relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment