Unverified Commit 82397498 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #206 from CyrusBiotechnology/custom-template

adding a script for threading a sequence onto a structure
parents 9dd9cea4 25feff5f
...@@ -21,6 +21,7 @@ from typing import Mapping, Optional, Sequence, Any ...@@ -21,6 +21,7 @@ from typing import Mapping, Optional, Sequence, Any
import numpy as np import numpy as np
from openfold.data import templates, parsers, mmcif_parsing from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.templates import get_custom_template_features
from openfold.data.tools import jackhmmer, hhblits, hhsearch from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -259,6 +260,41 @@ def make_msa_features( ...@@ -259,6 +260,41 @@ def make_msa_features(
return features return features
def make_sequence_features_with_custom_template(
    sequence: str,
    mmcif_path: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str) -> FeatureDict:
    """Build a feature dict for one sequence threaded onto one template.

    Rather than deriving template features from an alignment/template
    search, the structure at ``mmcif_path`` (chain ``chain_id``) is used
    directly as the sole template. The "MSA" consists of only the query
    sequence itself, with an all-zero deletion matrix.
    """
    seq_feats = make_sequence_features(
        sequence=sequence,
        description=pdb_id,
        num_res=len(sequence),
    )

    # Single-sequence MSA: one alignment containing just the query.
    single_seq_msa = [[sequence]]
    zero_deletions = [[[0] * len(sequence)]]
    msa_feats = make_msa_features(single_seq_msa, zero_deletions)

    template_result = get_custom_template_features(
        mmcif_path=mmcif_path,
        query_sequence=sequence,
        pdb_id=pdb_id,
        chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
    )

    # Later groups win on key collisions, matching dict-merge semantics.
    merged = dict(seq_feats)
    merged.update(msa_feats)
    merged.update(template_result.features)
    return merged
class AlignmentRunner: class AlignmentRunner:
"""Runs alignment tools and saves the results""" """Runs alignment tools and saves the results"""
def __init__( def __init__(
......
...@@ -913,6 +913,56 @@ def _process_single_hit( ...@@ -913,6 +913,56 @@ def _process_single_hit(
return SingleHitResult(features=None, error=error, warning=None) return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):
    """Extract template features for a query from a single mmCIF structure.

    Args:
        mmcif_path: Path to the mmCIF file to use as the template.
        query_sequence: The sequence being threaded onto the template.
        pdb_id: Identifier used for the parsed structure.
        chain_id: Which chain of the mmCIF to use as the template.
        kalign_binary_path: Path to the kalign executable used downstream.

    Returns:
        A TemplateSearchResult whose ``features`` are numpy arrays with a
        leading "num_templates" dimension of size 1.

    Raises:
        ValueError: If the mmCIF file cannot be parsed.
    """
    # Don't shadow the mmcif_path parameter with the file handle.
    with open(mmcif_path, "r") as f:
        cif_string = f.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    # Fail loudly instead of an opaque AttributeError below.
    if mmcif_parse_result.mmcif_object is None:
        raise ValueError(
            f"Failed to parse mmCIF file {mmcif_path}: "
            f"{mmcif_parse_result.errors}"
        )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: query residue i is threaded onto template residue i.
    mapping = {i: i for i in range(len(query_sequence))}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    # There is exactly one template, taken at face value.
    features["template_sum_probs"] = [1.0]

    # Stack each feature into an array with a leading num_templates (=1) dim
    # and cast to the dtype expected downstream.
    template_features = {
        name: np.stack([features[name]], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateSearchResult: class TemplateSearchResult:
features: Mapping[str, Any] features: Mapping[str, Any]
......
import os
import glob
import importlib

# Treat every .py file in this directory (except __init__.py) as a public
# submodule of the package.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each submodule and expose it as an attribute of the package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)
# Avoid needlessly cluttering the global namespace. The loop variable is
# only bound when at least one submodule exists.
del _files
if __all__:
    del _name
import os
import glob
import importlib

# Treat every .py file in this directory (except __init__.py) as a public
# submodule of the package.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each submodule and expose it as an attribute of the package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)
# Avoid needlessly cluttering the global namespace. The loop variable is
# only bound when at least one submodule exists.
del _files
if __all__:
    del _name
import os
import glob
import importlib

# Treat every .py file in this directory (except __init__.py) as a public
# submodule of the package.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each submodule and expose it as an attribute of the package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)
# Avoid needlessly cluttering the global namespace. The loop variable is
# only bound when at least one submodule exists.
del _files
if __all__:
    del _name
import os
import glob
import importlib

# The compiled-extension subpackage is imported explicitly; everything else
# is discovered from the .py files in this directory.
from . import kernel

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
] + ["kernel"]
# Import each submodule and expose it as an attribute of the package.
for _name in __all__:
    globals()[_name] = importlib.import_module("." + _name, __name__)
# Avoid needlessly cluttering the global namespace. __all__ always contains
# at least "kernel", so the loop variable is bound.
del _files, _name
import json
import logging
import os
import re
import time
import numpy
import torch
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
# Module-level logger configuration for the script utilities.
logging.basicConfig()
# NOTE(review): getLogger(__file__) keys the logger by file path; __name__ is
# the conventional choice -- confirm nothing relies on the path-based name
# before changing it.
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Count the model checkpoints named by the comma-separated CLI paths.

    Either argument may be None/empty, in which case it contributes zero.
    """
    total = 0
    for path_spec in (openfold_checkpoint_path, jax_param_path):
        if path_spec:
            total += len(path_spec.split(","))
    return total
def get_model_basename(model_path):
    """Return the final path component of model_path without its extension."""
    normalized = os.path.normpath(model_path)
    filename = os.path.basename(normalized)
    stem, _ext = os.path.splitext(filename)
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if necessary) and return the prediction directory for a model.

    In multiple-model mode each model writes to its own subdirectory under
    "predictions"; otherwise all predictions share a single directory.
    """
    parts = [output_dir, "predictions"]
    if multiple_model_mode:
        parts.append(model_name)
    prediction_dir = os.path.join(*parts)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def load_models_from_command_line(config, model_device, openfold_checkpoint_path, jax_param_path, output_dir):
    """Yield (model, output_directory) pairs for every checkpoint requested.

    Both path arguments are comma-separated lists. JAX parameter files are
    loaded first, then OpenFold checkpoints (either DeepSpeed checkpoint
    directories, which are converted to a .pt file under output_dir, or
    plain .pt files). Each yielded model is in eval mode on model_device.

    Raises:
        ValueError: If neither path argument is provided.
    """
    # Create the output directory
    multiple_model_mode = count_models_to_evaluate(openfold_checkpoint_path, jax_param_path) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if jax_param_path:
        for path in jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # File names look like "params_model_1"; everything after the
            # first underscore-delimited token names the AlphaFold version.
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(
                model, path, version=model_version
            )
            model = model.to(model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, model_basename, multiple_model_mode)
            yield model, output_directory

    if openfold_checkpoint_path:
        for path in openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint: convert to a fp32 .pt file once,
                # then reuse the converted file on later runs.
                ckpt_path = os.path.join(
                    output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                d = torch.load(ckpt_path)
                # Converted checkpoints always carry EMA weights.
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path)
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)

            model = model.to(model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
            yield model, output_directory

    if not jax_param_path and not openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )
def parse_fasta(data):
    """Parse a FASTA-formatted string into parallel (tags, seqs) lists.

    Each tag is the first whitespace-delimited token of its header line;
    each sequence has internal newlines removed.
    """
    # Drop a dangling ">" at the end of any line before splitting on records.
    cleaned = re.sub('>$', '', data, flags=re.M)
    entries = []
    for record in cleaned.split('>'):
        # First split isolates the header; the remainder is the sequence.
        for piece in record.strip().split('\n', 1):
            entries.append(piece.replace('\n', ''))
    # The text before the first ">" produces an empty leading entry.
    entries = entries[1:]
    headers, seqs = entries[::2], entries[1::2]
    tags = [header.split()[0] for header in headers]
    return tags, seqs
def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """Merge timing_dict into the JSON timings file and return its path.

    Existing entries are preserved unless overwritten by timing_dict. A
    corrupt timings file is logged and replaced rather than crashing.
    """
    timings = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
def run_model(model, batch, tag, output_dir):
    """Run one no-grad forward pass of the model on a processed batch.

    Template processing is disabled for the duration of the call when the
    batch carries no "template_"-prefixed features; the caller's setting is
    restored afterwards. Inference wall time is logged and recorded in
    <output_dir>/timings.json. Returns the raw model output.
    """
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any([
            "template_" in k for k in batch
        ])

        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))

        # Restore the caller's template setting before returning.
        model.config.template.enabled = template_enabled

    return out
def prep_output(out, batch, feature_dict, feature_processor, config_preset, multimer_ri_gap, subtract_plddt):
    """Convert raw model output into an unrelaxed Protein object.

    Per-residue pLDDT is broadcast to per-atom B-factors (optionally stored
    as 100 - pLDDT), template identifiers are attached as PDB "parents",
    and residue indices of multi-chain inputs (whose chains were offset by
    multimer_ri_gap upstream) are renumbered per chain.
    """
    plddt = out["plddt"]

    # pLDDT is per-residue; repeat it across the atom dimension so every
    # atom of a residue shares the residue's confidence as its B-factor.
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    if subtract_plddt:
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if feature_processor.config.common.use_templates and "template_domain_names" in feature_dict:
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[
            :feature_processor.config.predict.max_templates
        ]

        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[
                :feature_processor.config.predict.max_templates
            ]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={config_preset}",
    ])

    # For multi-chain FASTAs
    ri = feature_dict["residue_index"]
    # Chains were spaced multimer_ri_gap apart in residue_index upstream, so
    # subtracting the positional index and dividing recovers a chain id.
    chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
    chain_index = chain_index.astype(numpy.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * multimer_ri_gap

        # Renumber in place so each chain's residue_index restarts near zero.
        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
    """Run Amber relaxation on a prediction and write the relaxed PDB.

    The relaxed structure is saved to
    <output_directory>/<output_name>_relaxed.pdb and the relaxation wall
    time is recorded in <output_directory>/timings.json.
    """
    amber_relaxer = relax.AmberRelaxation(
        use_gpu=(model_device != "cpu"),
        **config.relax,
    )

    t = time.perf_counter()
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
    if "cuda" in model_device:
        # Pin the relaxer to the same GPU the model ran on.
        device_no = model_device.split(":")[-1]
        os.environ["CUDA_VISIBLE_DEVICES"] = device_no
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    # Restore the caller's device visibility afterwards.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    relaxation_time = time.perf_counter() - t

    logger.info(f"Relaxation time: {relaxation_time}")
    update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_directory, f'{output_name}_relaxed.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)
    logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
...@@ -13,26 +13,23 @@ ...@@ -13,26 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
from copy import deepcopy
from datetime import date
import logging import logging
import math import math
import numpy as np import numpy as np
import os import os
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
update_timings, relax_protein
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO) logger.setLevel(level=logging.INFO)
import pickle import pickle
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
import random import random
import sys
import time import time
import torch import torch
import re
torch_versions = torch.__version__.split(".") torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0]) torch_major_version = int(torch_versions[0])
...@@ -46,15 +43,11 @@ if( ...@@ -46,15 +43,11 @@ if(
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
from openfold.config import model_config, NUM_RES from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
...@@ -107,102 +100,6 @@ def round_up_seqlen(seqlen): ...@@ -107,102 +100,6 @@ def round_up_seqlen(seqlen):
return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL
def run_model(model, batch, tag, args):
with torch.no_grad():
# Temporarily disable templates if there aren't any in the batch
template_enabled = model.config.template.enabled
model.config.template.enabled = template_enabled and any([
"template_" in k for k in batch
])
logger.info(f"Running inference for {tag}...")
t = time.perf_counter()
out = model(batch)
inference_time = time.perf_counter() - t
logger.info(f"Inference time: {inference_time}")
update_timings({"inference": inference_time}, os.path.join(args.output_dir, "timings.json"))
model.config.template.enabled = template_enabled
return out
def prep_output(out, batch, feature_dict, feature_processor, args):
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
if(args.subtract_plddt):
plddt_b_factors = 100 - plddt_b_factors
# Prep protein metadata
template_domain_names = []
template_chain_index = None
if(feature_processor.config.common.use_templates and "template_domain_names" in feature_dict):
template_domain_names = [
t.decode("utf-8") for t in feature_dict["template_domain_names"]
]
# This works because templates are not shuffled during inference
template_domain_names = template_domain_names[
:feature_processor.config.predict.max_templates
]
if("template_chain_index" in feature_dict):
template_chain_index = feature_dict["template_chain_index"]
template_chain_index = template_chain_index[
:feature_processor.config.predict.max_templates
]
no_recycling = feature_processor.config.common.max_recycling_iters
remark = ', '.join([
f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={args.config_preset}",
])
# For multi-chain FASTAs
ri = feature_dict["residue_index"]
chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
chain_index = chain_index.astype(np.int64)
cur_chain = 0
prev_chain_max = 0
for i, c in enumerate(chain_index):
if(c != cur_chain):
cur_chain = c
prev_chain_max = i + cur_chain * args.multimer_ri_gap
batch["residue_index"][i] -= prev_chain_max
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors,
chain_index=chain_index,
remark=remark,
parents=template_domain_names,
parents_chain_index=template_chain_index,
)
return unrelaxed_protein
def parse_fasta(data):
data = re.sub('>$', '', data, flags=re.M)
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
return tags, seqs
def generate_feature_dict( def generate_feature_dict(
tags, tags,
seqs, seqs,
...@@ -235,98 +132,6 @@ def generate_feature_dict( ...@@ -235,98 +132,6 @@ def generate_feature_dict(
return feature_dict return feature_dict
def get_model_basename(model_path):
return os.path.splitext(
os.path.basename(
os.path.normpath(model_path)
)
)[0]
def make_output_directory(output_dir, model_name, multiple_model_mode):
if multiple_model_mode:
prediction_dir = os.path.join(output_dir, "predictions", model_name)
else:
prediction_dir = os.path.join(output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
return prediction_dir
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
model_count = 0
if openfold_checkpoint_path:
model_count += len(openfold_checkpoint_path.split(","))
if jax_param_path:
model_count += len(jax_param_path.split(","))
return model_count
def load_models_from_command_line(args, config):
# Create the output directory
multiple_model_mode = count_models_to_evaluate(args.openfold_checkpoint_path, args.jax_param_path) > 1
if multiple_model_mode:
logger.info(f"evaluating multiple models")
if args.jax_param_path:
for path in args.jax_param_path.split(","):
model_basename = get_model_basename(path)
model_version = "_".join(model_basename.split("_")[1:])
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(
model, path, version=model_version
)
model = model.to(args.model_device)
logger.info(
f"Successfully loaded JAX parameters at {path}..."
)
output_directory = make_output_directory(args.output_dir, model_basename, multiple_model_mode)
yield model, output_directory
if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","):
model = AlphaFold(config)
model = model.eval()
checkpoint_basename = get_model_basename(path)
if os.path.isdir(path):
# A DeepSpeed checkpoint
ckpt_path = os.path.join(
args.output_dir,
checkpoint_basename + ".pt",
)
if not os.path.isfile(ckpt_path):
convert_zero_checkpoint_to_fp32_state_dict(
path,
ckpt_path,
)
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
ckpt_path = path
d = torch.load(ckpt_path)
if "ema" in d:
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
model = model.to(args.model_device)
logger.info(
f"Loaded OpenFold parameters at {path}..."
)
output_directory = make_output_directory(args.output_dir, checkpoint_basename, multiple_model_mode)
yield model, output_directory
if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
)
def list_files_with_extensions(dir, extensions): def list_files_with_extensions(dir, extensions):
return [f for f in os.listdir(dir) if f.endswith(extensions)] return [f for f in os.listdir(dir) if f.endswith(extensions)]
...@@ -389,7 +194,13 @@ def main(args): ...@@ -389,7 +194,13 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]]) seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn) sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {} feature_dicts = {}
for model, output_directory in load_models_from_command_line(args, config): model_generator = load_models_from_command_line(
config,
args.model_device,
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
for model, output_directory in model_generator:
cur_tracing_interval = 0 cur_tracing_interval = 0
for (tag, tags), seqs in sorted_targets: for (tag, tags), seqs in sorted_targets:
output_name = f'{tag}_{args.config_preset}' output_name = f'{tag}_{args.config_preset}'
...@@ -440,7 +251,7 @@ def main(args): ...@@ -440,7 +251,7 @@ def main(args):
) )
cur_tracing_interval = rounded_seqlen cur_tracing_interval = rounded_seqlen
out = run_model(model, processed_feature_dict, tag, args) out = run_model(model, processed_feature_dict, tag, args.output_dir)
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map( processed_feature_dict = tensor_tree_map(
...@@ -454,7 +265,8 @@ def main(args): ...@@ -454,7 +265,8 @@ def main(args):
processed_feature_dict, processed_feature_dict,
feature_dict, feature_dict,
feature_processor, feature_processor,
args args.config_preset,
args.multimer_ri_gap
) )
unrelaxed_output_path = os.path.join( unrelaxed_output_path = os.path.join(
...@@ -467,33 +279,9 @@ def main(args): ...@@ -467,33 +279,9 @@ def main(args):
logger.info(f"Output written to {unrelaxed_output_path}...") logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation: if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
)
# Relax the prediction. # Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...") logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter() relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if "cuda" in args.model_device:
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
relaxation_time = time.perf_counter() - t
logger.info(f"Relaxation time: {relaxation_time}")
update_timings({"relaxation": relaxation_time}, os.path.join(args.output_dir, "timings.json"))
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_directory, f'{output_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
if args.save_outputs: if args.save_outputs:
output_dict_path = os.path.join( output_dict_path = os.path.join(
...@@ -504,22 +292,6 @@ def main(args): ...@@ -504,22 +292,6 @@ def main(args):
logger.info(f"Model output written to {output_dict_path}...") logger.info(f"Model output written to {output_dict_path}...")
def update_timings(dict, output_file=os.path.join(os.getcwd(), "timings.json")):
"""Write dictionary of one or more run step times to a file"""
import json
if os.path.exists(output_file):
with open(output_file, "r") as f:
try:
timings = json.load(f)
except json.JSONDecodeError:
logger.info(f"Overwriting non-standard JSON in {output_file}.")
timings = {}
else:
timings = {}
timings.update(dict)
with open(output_file, "w") as f:
json.dump(timings, f)
return output_file
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
import argparse
import os
import logging
import random
import numpy
import torch
from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
relax_protein
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from scripts.utils import add_data_args
# Script-level logging configuration.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
# set_float32_matmul_precision is only available from torch >= 1.12,
# hence the explicit version gate.
if(
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: no gradients are needed anywhere.
torch.set_grad_enabled(False)
def main(args):
    """Thread one FASTA sequence onto a custom template and predict it.

    Builds features from a single template structure (no MSA/template
    search), runs every model named on the command line, writes the
    unrelaxed prediction, then relaxes it.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)

    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    # This script threads exactly one sequence onto one template chain.
    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")

    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    output_name = f'{query_tag}_{args.config_preset}'
    # NOTE(review): the loop rebinds processed_feature_dict to numpy arrays
    # below, so a second model from the generator would receive converted
    # (non-tensor) features -- verify multi-model invocations of this script.
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore
        processed_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            processed_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            200, # this is the ri_multimer_gap. There's no multimer sequences here, so it doesnt matter what its set to
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
    parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")
    parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
             of the pLDDT itself"""
    )
    # type=int: main() feeds this value straight into numpy.random.seed and
    # torch.manual_seed, both of which reject strings.
    parser.add_argument(
        "--data_random_seed", type=int, default=None
    )
    add_data_args(parser)
    args = parser.parse_args()

    # Default to the bundled JAX parameters for the chosen preset when no
    # checkpoint of either kind was supplied.
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment