"examples/vscode:/vscode.git/clone" did not exist on "f955203309a06824a1d4c687c8b11131dfaf8695"
Unverified Commit 82397498 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #206 from CyrusBiotechnology/custom-template

adding a script for threading a sequence onto a structure
parents 9dd9cea4 25feff5f
......@@ -21,6 +21,7 @@ from typing import Mapping, Optional, Sequence, Any
import numpy as np
from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.templates import get_custom_template_features
from openfold.data.tools import jackhmmer, hhblits, hhsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
......@@ -259,6 +260,41 @@ def make_msa_features(
return features
def make_sequence_features_with_custom_template(
    sequence: str,
    mmcif_path: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str) -> FeatureDict:
    """
    Build the feature dict for a single sequence using one user-supplied
    template structure rather than alignment-derived templates.
    """
    seq_feats = make_sequence_features(
        sequence=sequence,
        description=pdb_id,
        num_res=len(sequence),
    )

    # A trivial "MSA" containing only the query itself, with no deletions.
    msa_feats = make_msa_features(
        [[sequence]],
        [[[0] * len(sequence)]],
    )

    template_result = get_custom_template_features(
        mmcif_path=mmcif_path,
        query_sequence=sequence,
        pdb_id=pdb_id,
        chain_id=chain_id,
        kalign_binary_path=kalign_binary_path
    )

    # Merge all three feature groups into one dict (later groups win on clash,
    # exactly as the {**a, **b, **c} form would).
    combined = dict(seq_feats)
    combined.update(msa_feats)
    combined.update(template_result.features)
    return combined
class AlignmentRunner:
"""Runs alignment tools and saves the results"""
def __init__(
......
......@@ -913,6 +913,56 @@ def _process_single_hit(
return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):
    """Build template features from a single, user-supplied template structure.

    Args:
        mmcif_path: Path to the template mmCIF file on disk.
        query_sequence: Query amino-acid sequence to thread onto the template.
        pdb_id: Identifier used for the template (e.g. a PDB id).
        chain_id: Chain within the mmCIF file to use as the template.
        kalign_binary_path: Path to the kalign executable used for realignment.

    Returns:
        A TemplateSearchResult whose `features` dict holds one stacked entry
        per TEMPLATE_FEATURES key (a single "hit" along axis 0).

    Raises:
        ValueError: If the mmCIF file cannot be parsed.
        KeyError: If `chain_id` is not present in the parsed structure.
    """
    # Use a distinct name for the handle; the original shadowed the
    # `mmcif_path` parameter with the open file object.
    with open(mmcif_path, "r") as mmcif_file:
        cif_string = mmcif_file.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    # Fail loudly on a parse failure instead of crashing later with an
    # AttributeError on a None mmcif_object.
    if mmcif_parse_result.mmcif_object is None:
        raise ValueError(
            f"Failed to parse mmCIF file {mmcif_path}: "
            f"{mmcif_parse_result.errors}"
        )

    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: residue i of the query maps to residue i of the template.
    mapping = {i: i for i in range(len(query_sequence))}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    # The single custom template is treated as a perfect hit.
    features["template_sum_probs"] = [1.0]

    # Stack each feature into a length-1 "template" dimension with the dtype
    # expected downstream (replaces the previous three-loop construction).
    template_features = {
        name: np.stack([features[name]], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
features: Mapping[str, Any]
......
# Package initializer: dynamically import every sibling .py module in this
# package directory and re-export each one as a package attribute.
import os
import glob
import importlib as importlib

# All .py files in the package directory (still includes this __init__.py).
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Public API: one entry per module file, minus the ".py" suffix and __init__.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each listed module and bind it on the package namespace.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
# NOTE(review): `del _m` raises NameError if the package had no submodules
# (loop never ran) — presumably never the case here; confirm.
del _files, _m, _modules
# Package initializer: dynamically import every sibling .py module in this
# package directory and re-export each one as a package attribute.
import os
import glob
import importlib as importlib

# All .py files in the package directory (still includes this __init__.py).
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Public API: one entry per module file, minus the ".py" suffix and __init__.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each listed module and bind it on the package namespace.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
# NOTE(review): `del _m` raises NameError if the package had no submodules
# (loop never ran) — presumably never the case here; confirm.
del _files, _m, _modules
# Package initializer: dynamically import every sibling .py module in this
# package directory and re-export each one as a package attribute.
import os
import glob
import importlib as importlib

# All .py files in the package directory (still includes this __init__.py).
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Public API: one entry per module file, minus the ".py" suffix and __init__.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each listed module and bind it on the package namespace.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
# NOTE(review): `del _m` raises NameError if the package had no submodules
# (loop never ran) — presumably never the case here; confirm.
del _files, _m, _modules
# Package initializer: import the `kernel` subpackage plus every sibling .py
# module in this directory, and re-export them all as package attributes.
import os
import glob
import importlib as importlib

from . import kernel

# All .py files in the package directory (still includes this __init__.py).
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Public API: every module file (minus ".py" and __init__) plus `kernel`.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
] + ["kernel"]
# Import each listed module and bind it on the package namespace.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
import json
import logging
import os
import re
import time
import numpy
import torch
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
# Module-level logger; basicConfig is a no-op if the root logger was already
# configured by the embedding application.
logging.basicConfig()
# NOTE(review): the logger is named after __file__ (a path), not __name__ —
# unusual, but kept as-is.
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Count the checkpoints named by the two comma-separated path strings.

    Either argument may be None or empty, in which case it contributes zero.
    """
    total = 0
    for path_spec in (openfold_checkpoint_path, jax_param_path):
        if path_spec:
            total += len(path_spec.split(","))
    return total
def get_model_basename(model_path):
    """Return the checkpoint's filename without directory or extension."""
    normalized = os.path.normpath(model_path)
    stem, _extension = os.path.splitext(os.path.basename(normalized))
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if needed) and return the prediction directory for one model.

    In multiple-model mode each model gets its own subdirectory under
    "predictions"; otherwise a single shared "predictions" directory is used.
    """
    parts = ["predictions"]
    if multiple_model_mode:
        parts.append(model_name)
    prediction_dir = os.path.join(output_dir, *parts)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def load_models_from_command_line(config, model_device, openfold_checkpoint_path, jax_param_path, output_dir):
    """Yield (model, output_directory) for every checkpoint path given.

    `openfold_checkpoint_path` and `jax_param_path` are comma-separated path
    lists; at least one must be non-empty or ValueError is raised. Each yielded
    model is in eval mode and moved to `model_device`.
    """
    # Create the output directory
    multiple_model_mode = count_models_to_evaluate(openfold_checkpoint_path, jax_param_path) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if jax_param_path:
        for path in jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # Param files are named like "params_<version>"; drop the prefix.
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(
                model, path, version=model_version
            )
            model = model.to(model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, model_basename, multiple_model_mode)
            yield model, output_directory

    if openfold_checkpoint_path:
        for path in openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint
                ckpt_path = os.path.join(
                    output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    # Consolidate the sharded ZeRO checkpoint into one .pt
                    # file; the result is cached next to the outputs.
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                d = torch.load(ckpt_path)
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path)
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)
            model = model.to(model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
            yield model, output_directory

    if not jax_param_path and not openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )
def parse_fasta(data):
    """Parse FASTA text into parallel lists of tags and sequences.

    Tags are the first whitespace-delimited token of each header line;
    sequences have their internal newlines removed.
    """
    # Drop '>' characters that terminate a line, then split into records.
    cleaned = re.sub('>$', '', data, flags=re.M)
    flattened = []
    for record in cleaned.split('>'):
        for piece in record.strip().split('\n', 1):
            flattened.append(piece.replace('\n', ''))
    # Discard the (empty) chunk preceding the first '>'.
    flattened = flattened[1:]
    tags = [header.split()[0] for header in flattened[::2]]
    seqs = flattened[1::2]
    return tags, seqs
def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """Merge *timing_dict* into the JSON timings file and return its path.

    Existing entries are preserved; unreadable JSON is logged and overwritten.
    """
    existing = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as handle:
            try:
                existing = json.load(handle)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                existing = {}
    existing.update(timing_dict)
    with open(output_file, "w") as handle:
        json.dump(existing, handle)
    return output_file
def run_model(model, batch, tag, output_dir):
    """Run one no-grad forward pass of *model* on *batch*; log and record timing.

    Writes the inference time into <output_dir>/timings.json and returns the
    raw model output.
    """
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any([
            "template_" in k for k in batch
        ])

        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))

        # Restore the caller's template setting.
        model.config.template.enabled = template_enabled

    return out
def prep_output(out, batch, feature_dict, feature_processor, config_preset, multimer_ri_gap, subtract_plddt):
    """Assemble an unrelaxed Protein object from one model output.

    Uses per-residue pLDDT (optionally inverted as 100 - pLDDT) as B-factors,
    attaches template provenance as PDB "parents", and undoes the residue-index
    gap used to separate chains in multi-chain FASTAs.
    """
    plddt = out["plddt"]
    # Broadcast per-residue pLDDT over all atom slots for the B-factor column.
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    if subtract_plddt:
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if feature_processor.config.common.use_templates and "template_domain_names" in feature_dict:
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[
            :feature_processor.config.predict.max_templates
        ]

        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[
                :feature_processor.config.predict.max_templates
            ]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={config_preset}",
    ])

    # For multi-chain FASTAs
    ri = feature_dict["residue_index"]
    # Chains were offset by multiples of multimer_ri_gap; recover chain ids
    # by subtracting the sequential position.
    chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
    chain_index = chain_index.astype(numpy.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * multimer_ri_gap

        # Shift each chain's residue numbering back to start near 1.
        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
    """Run Amber relaxation on a predicted protein and write the relaxed PDB.

    Records the relaxation time in <output_directory>/timings.json and writes
    <output_name>_relaxed.pdb.
    """
    amber_relaxer = relax.AmberRelaxation(
        use_gpu=(model_device != "cpu"),
        **config.relax,
    )

    t = time.perf_counter()
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
    if "cuda" in model_device:
        device_no = model_device.split(":")[-1]
        # Restrict the relaxer to the GPU the model ran on.
        os.environ["CUDA_VISIBLE_DEVICES"] = device_no
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    # Restore the caller's device visibility.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    relaxation_time = time.perf_counter() - t

    logger.info(f"Relaxation time: {relaxation_time}")
    update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_directory, f'{output_name}_relaxed.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)
    logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
......@@ -13,26 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from copy import deepcopy
from datetime import date
import logging
import math
import numpy as np
import os
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
update_timings, relax_protein
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import pickle
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
import random
import sys
import time
import torch
import re
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
......@@ -46,15 +43,11 @@ if(
torch.set_grad_enabled(False)
from openfold.config import model_config, NUM_RES
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
......@@ -107,102 +100,6 @@ def round_up_seqlen(seqlen):
return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL
def run_model(model, batch, tag, args):
    """Run one no-grad forward pass of *model* on *batch*; log and record timing.

    Pre-refactor variant that takes the full `args` namespace; writes the
    inference time into <args.output_dir>/timings.json and returns the output.
    """
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        model.config.template.enabled = template_enabled and any([
            "template_" in k for k in batch
        ])

        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({"inference": inference_time}, os.path.join(args.output_dir, "timings.json"))

        # Restore the caller's template setting.
        model.config.template.enabled = template_enabled

    return out
def prep_output(out, batch, feature_dict, feature_processor, args):
    """Assemble an unrelaxed Protein object from one model output.

    Pre-refactor variant that reads options from the `args` namespace.
    """
    plddt = out["plddt"]
    # NOTE(review): mean_plddt is computed but never used in this function.
    mean_plddt = np.mean(plddt)
    # Broadcast per-residue pLDDT over all atom slots for the B-factor column.
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    if(args.subtract_plddt):
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    if(feature_processor.config.common.use_templates and "template_domain_names" in feature_dict):
        template_domain_names = [
            t.decode("utf-8") for t in feature_dict["template_domain_names"]
        ]

        # This works because templates are not shuffled during inference
        template_domain_names = template_domain_names[
            :feature_processor.config.predict.max_templates
        ]

        if("template_chain_index" in feature_dict):
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[
                :feature_processor.config.predict.max_templates
            ]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={args.config_preset}",
    ])

    # For multi-chain FASTAs
    ri = feature_dict["residue_index"]
    # Chains were offset by multiples of multimer_ri_gap; recover chain ids.
    chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
    chain_index = chain_index.astype(np.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if(c != cur_chain):
            cur_chain = c
            prev_chain_max = i + cur_chain * args.multimer_ri_gap

        # Shift each chain's residue numbering back to start near 1.
        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def parse_fasta(data):
    """Split FASTA text into ([tag, ...], [sequence, ...]).

    A tag is the first whitespace-delimited token of a header; sequence
    newlines are stripped out.
    """
    # Remove any '>' that ends a line, then break into '>'-delimited records.
    text = re.sub('>$', '', data, flags=re.M)
    pieces = []
    for entry in text.split('>'):
        header_and_body = entry.strip().split('\n', 1)
        pieces.extend(part.replace('\n', '') for part in header_and_body)
    # First element is the (empty) text before the leading '>'; drop it.
    pieces = pieces[1:]
    tags = [h.split()[0] for h in pieces[::2]]
    seqs = pieces[1::2]
    return tags, seqs
def generate_feature_dict(
tags,
seqs,
......@@ -235,98 +132,6 @@ def generate_feature_dict(
return feature_dict
def get_model_basename(model_path):
    """Return the filename of *model_path* stripped of directory and extension."""
    filename = os.path.basename(os.path.normpath(model_path))
    return os.path.splitext(filename)[0]
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create and return the prediction directory for one model.

    Per-model subdirectories are used only in multiple-model mode.
    """
    if multiple_model_mode:
        target = os.path.join(output_dir, "predictions", model_name)
    else:
        target = os.path.join(output_dir, "predictions")
    os.makedirs(target, exist_ok=True)
    return target
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Return the number of checkpoints listed across both comma-separated
    path strings; None/empty arguments contribute nothing."""
    counts = [
        len(spec.split(","))
        for spec in (openfold_checkpoint_path, jax_param_path)
        if spec
    ]
    return sum(counts)
def load_models_from_command_line(args, config):
    """Yield (model, output_directory) for every checkpoint path in *args*.

    Pre-refactor variant taking the full `args` namespace. Both
    `args.openfold_checkpoint_path` and `args.jax_param_path` are
    comma-separated lists; at least one must be set or ValueError is raised.
    """
    # Create the output directory
    multiple_model_mode = count_models_to_evaluate(args.openfold_checkpoint_path, args.jax_param_path) > 1
    if multiple_model_mode:
        logger.info(f"evaluating multiple models")

    if args.jax_param_path:
        for path in args.jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # Param files are named like "params_<version>"; drop the prefix.
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(
                model, path, version=model_version
            )
            model = model.to(args.model_device)
            logger.info(
                f"Successfully loaded JAX parameters at {path}..."
            )
            output_directory = make_output_directory(args.output_dir, model_basename, multiple_model_mode)
            yield model, output_directory

    if args.openfold_checkpoint_path:
        for path in args.openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint
                ckpt_path = os.path.join(
                    args.output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    # Consolidate the sharded ZeRO checkpoint into one .pt file.
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                d = torch.load(ckpt_path)
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path)
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)
            model = model.to(args.model_device)
            logger.info(
                f"Loaded OpenFold parameters at {path}..."
            )
            output_directory = make_output_directory(args.output_dir, checkpoint_basename, multiple_model_mode)
            yield model, output_directory

    if not args.jax_param_path and not args.openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )
def list_files_with_extensions(dir, extensions):
    """Return the names of entries in *dir* ending with any of *extensions*
    (a string or tuple of strings, as accepted by str.endswith)."""
    matches = []
    for entry in os.listdir(dir):
        if entry.endswith(extensions):
            matches.append(entry)
    return matches
......@@ -389,7 +194,13 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {}
for model, output_directory in load_models_from_command_line(args, config):
model_generator = load_models_from_command_line(
config,
args.model_device,
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
for model, output_directory in model_generator:
cur_tracing_interval = 0
for (tag, tags), seqs in sorted_targets:
output_name = f'{tag}_{args.config_preset}'
......@@ -440,7 +251,7 @@ def main(args):
)
cur_tracing_interval = rounded_seqlen
out = run_model(model, processed_feature_dict, tag, args)
out = run_model(model, processed_feature_dict, tag, args.output_dir)
# Toss out the recycling dimensions --- we don't need them anymore
processed_feature_dict = tensor_tree_map(
......@@ -454,7 +265,8 @@ def main(args):
processed_feature_dict,
feature_dict,
feature_processor,
args
args.config_preset,
args.multimer_ri_gap
)
unrelaxed_output_path = os.path.join(
......@@ -467,33 +279,9 @@ def main(args):
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
)
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if "cuda" in args.model_device:
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
relaxation_time = time.perf_counter() - t
logger.info(f"Relaxation time: {relaxation_time}")
update_timings({"relaxation": relaxation_time}, os.path.join(args.output_dir, "timings.json"))
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_directory, f'{output_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
if args.save_outputs:
output_dict_path = os.path.join(
......@@ -504,22 +292,6 @@ def main(args):
logger.info(f"Model output written to {output_dict_path}...")
def update_timings(dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """Write dictionary of one or more run step times to a file.

    Merges the given entries into any existing JSON at *output_file* and
    returns the path written.
    """
    # NOTE(review): the parameter name shadows the builtin `dict`; renaming it
    # would change the keyword interface, so it is only flagged here.
    import json
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                # Unparseable timings file: warn and start fresh.
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
import argparse
import os
import logging
import random
import numpy
import torch
from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
relax_protein
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from scripts.utils import add_data_args
# Module-level logging setup (logger named after the file path).
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

# Feature-gate on the installed torch version (needs >= 1.12).
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: gradients are never needed.
torch.set_grad_enabled(False)
def main(args):
    """Thread a single FASTA sequence onto a template structure and predict.

    Builds features from one user-supplied mmCIF template (no MSA search),
    runs every model named on the command line, and writes unrelaxed and
    relaxed PDBs into per-model output directories.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    config = model_config(args.config_preset)

    # Seed numpy and torch (torch gets a distinct but derived seed).
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)
    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())
    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")
    query_sequence = sequences[0]
    query_tag = tags[0]

    # Features come from the single custom template rather than an alignment.
    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)
    output_name = f'{query_tag}_{args.config_preset}'
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore
        processed_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            processed_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            # multimer_ri_gap: there are no multimer sequences here, so any
            # value works.
            200,
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))
        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
# Command-line entry point for the threading script.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
    parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")
    parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help=""""Whether to output (100 - pLDDT) in the B-factor column instead
             of the pLDDT itself"""
    )
    parser.add_argument(
        "--data_random_seed", type=str, default=None
    )
    # Adds the shared data-pipeline arguments (e.g. kalign_binary_path).
    add_data_args(parser)
    args = parser.parse_args()

    # Default to the bundled JAX parameters when no checkpoint is given.
    if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if(args.model_device == "cpu" and torch.cuda.is_available()):
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment