Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
import argparse import argparse
import ctypes
from datetime import date from datetime import date
import sys
def add_data_args(parser: argparse.ArgumentParser): def add_data_args(parser: argparse.ArgumentParser):
...@@ -43,7 +45,7 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -43,7 +45,7 @@ def add_data_args(parser: argparse.ArgumentParser):
'--kalign_binary_path', type=str, default='/usr/bin/kalign' '--kalign_binary_path', type=str, default='/usr/bin/kalign'
) )
parser.add_argument( parser.add_argument(
'--max_template_date', type=str, '--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"), default=date.today().strftime("%Y-%m-%d"),
) )
parser.add_argument( parser.add_argument(
...@@ -52,3 +54,67 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -52,3 +54,67 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
'--release_dates_path', type=str, default=None '--release_dates_path', type=str, default=None
) )
def get_nvidia_cc():
    """
    Detect the Compute Capability of the first GPU installed in the system.

    Returns:
        A tuple ``(cc, error)`` where exactly one element is not None:
        on success ``cc`` is a ``(major, minor)`` tuple of ints and
        ``error`` is None; on failure ``cc`` is None and ``error`` is a
        human-readable message.

    Adapted from a script by Jan Schlüter:
    https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """
    CUDA_SUCCESS = 0
    # Candidate CUDA driver library names across platforms; the last
    # entry covers the compatibility layer inside Docker images.
    libnames = [
        'libcuda.so',
        'libcuda.dylib',
        'cuda.dll',
        '/usr/local/cuda/compat/libcuda.so',  # For Docker
    ]
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        else:
            break
    else:
        # No driver library could be loaded at all.
        return None, "Could not load any of: " + ' '.join(libnames)

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    device = ctypes.c_int()
    error_str = ctypes.c_char_p()

    def _error_message(result):
        # Translate a CUDA error code into a readable message. Guard
        # against cuGetErrorString leaving the pointer NULL, which would
        # previously have crashed with AttributeError on .decode().
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        if error_str.value:
            return error_str.value.decode()
        return "Unknown error: CUDA call returned %d" % result

    result = cuda.cuInit(0)
    if result != CUDA_SUCCESS:
        return None, _error_message(result)
    result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
    if result != CUDA_SUCCESS:
        return None, _error_message(result)
    if nGpus.value < 1:
        return None, "No GPUs detected"
    result = cuda.cuDeviceGet(ctypes.byref(device), 0)
    if result != CUDA_SUCCESS:
        return None, _error_message(result)
    result = cuda.cuDeviceComputeCapability(
        ctypes.byref(cc_major), ctypes.byref(cc_minor), device
    )
    if result != CUDA_SUCCESS:
        return None, "Compute Capability not found"
    return (cc_major.value, cc_minor.value), None
...@@ -13,6 +13,7 @@ import glob ...@@ -13,6 +13,7 @@ import glob
import math import math
import os import os
from collections import OrderedDict from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment. # DeepSpeed data structures it has to be available in the current python environment.
...@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): ...@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return model return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """
    Read the global training step recorded in a DeepSpeed ZeRO checkpoint.

    DeepSpeed writes a ``latest`` file inside the checkpoint directory
    whose contents name the most recent tag, e.g. ``global_step1234``.

    Args:
        checkpoint_dir: Path to the checkpoint directory containing the
            ``latest`` file.

    Returns:
        The global step parsed from the tag, as an int.

    Raises:
        ValueError: If the ``latest`` file is missing, or its tag does
            not match the expected ``global_step<N>`` format.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    match = re.match(r"global_step([0-9]+)", tag)
    if match is None:
        # Previously a malformed tag crashed with AttributeError on
        # match.group(); raise a descriptive error instead.
        raise ValueError(
            f"Unrecognized checkpoint tag {tag!r} in {latest_path}"
        )
    return int(match.group(1))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -16,7 +16,9 @@ import os ...@@ -16,7 +16,9 @@ import os
from setuptools import setup, Extension, find_packages from setuptools import setup, Extension, find_packages
import subprocess import subprocess
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from scripts.utils import get_nvidia_cc
version_dependent_macros = [ version_dependent_macros = [
...@@ -26,48 +28,56 @@ version_dependent_macros = [ ...@@ -26,48 +28,56 @@ version_dependent_macros = [
] ]
extra_cuda_flags = [ extra_cuda_flags = [
'-std=c++14', '-std=c++14',
'-maxrregcount=50', '-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda' '--expt-extended-lambda'
] ]
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) if cuda_dir==None:
output = raw_output.split() print("CUDA is not found, cpu version is installed")
release_idx = output.index("release") + 1 return None, -1, 0
release = output[release_idx].split(".") else:
bare_metal_major = release[0] raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
bare_metal_minor = release[1][0] output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
return raw_output, bare_metal_major, bare_metal_minor compute_capabilities = set([
(3, 7), # K80, e.g.
(5, 2), # Titan X
(6, 1), # GeForce 1000-series
])
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70'] compute_capabilities.add((7, 0))
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') compute_capabilities.add((8, 0))
cc_flag.append('arch=compute_80,code=sm_80')
compute_capability, _ = get_nvidia_cc()
if compute_capability is not None:
compute_capabilities = set([compute_capability])
cc_flag = []
for major, minor in list(compute_capabilities):
cc_flag.extend([
'-gencode',
f'arch=compute_{major}{minor},code=sm_{major}{minor}',
])
extra_cuda_flags += cc_flag extra_cuda_flags += cc_flag
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
setup( if bare_metal_major != -1:
name='openfold', modules = [CUDAExtension(
version='0.1.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=[CUDAExtension(
name="attn_core_inplace_cuda", name="attn_core_inplace_cuda",
sources=[ sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp", "openfold/utils/kernel/csrc/softmax_cuda.cpp",
...@@ -75,34 +85,51 @@ setup( ...@@ -75,34 +85,51 @@ setup(
], ],
include_dirs=[ include_dirs=[
os.path.join( os.path.join(
os.path.dirname(os.path.abspath(__file__)), os.path.dirname(os.path.abspath(__file__)),
'openfold/utils/kernel/csrc/' 'openfold/utils/kernel/csrc/'
) )
], ],
extra_compile_args={ extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros, 'cxx': ['-O3'] + version_dependent_macros,
'nvcc': ( 'nvcc': (
['-O3', '--use_fast_math'] + ['-O3', '--use_fast_math'] +
version_dependent_macros + version_dependent_macros +
extra_cuda_flags extra_cuda_flags
), ),
} }
)], )]
else:
modules = [CppExtension(
name="attn_core_inplace_cuda",
sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp",
"openfold/utils/kernel/csrc/softmax_cuda_stub.cpp",
],
extra_compile_args={
'cxx': ['-O3'],
}
)]
setup(
name='openfold',
version='1.0.1',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=modules,
cmdclass={'build_ext': BuildExtension}, cmdclass={'build_ext': BuildExtension},
install_requires=[
'torch',
'deepspeed',
'biopython',
'ml-collections',
'numpy',
'scipy',
'pytorch_lightning',
'dm-tree',
],
classifiers=[ classifiers=[
'License :: OSI Approved :: Apache Software License', 'License :: OSI Approved :: Apache Software License',
'Operating System :: POSIX :: Linux', 'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.7,' 'Programming Language :: Python :: 3.7,'
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
], ],
) )
...@@ -10,6 +10,7 @@ consts = mlc.ConfigDict( ...@@ -10,6 +10,7 @@ consts = mlc.ConfigDict(
"n_seq": 13, "n_seq": 13,
"n_templ": 3, "n_templ": 3,
"n_extra": 17, "n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4, "eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for # For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values. # everyone if these take their real values.
......
...@@ -30,7 +30,10 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4): ...@@ -30,7 +30,10 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces = [] pieces = []
asym_ids = [] asym_ids = []
for idx in range(n_chain - 1): for idx in range(n_chain - 1):
piece = randint(min_chain_len, (n_res - sum(pieces) - n_chain + idx - min_chain_len)) n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
if n_stop <= min_chain_len:
break
piece = randint(min_chain_len, n_stop)
pieces.append(piece) pieces.append(piece)
asym_ids.extend(piece * [idx]) asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1]) asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
......
...@@ -45,7 +45,7 @@ class TestDataTransforms(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestDataTransforms(unittest.TestCase):
template_seq_one_hot = torch.FloatTensor(template_seq.shape[0], 20).zero_() template_seq_one_hot = torch.FloatTensor(template_seq.shape[0], 20).zero_()
template_seq_one_hot.scatter_(1, template_seq, 1) template_seq_one_hot.scatter_(1, template_seq, 1)
template_aatype = template_seq_one_hot.clone().detach().unsqueeze(0) template_aatype = template_seq_one_hot.clone().detach().unsqueeze(0)
protein = {'template_aatype': template_aatype} protein = {'template_aatype': template_aatype, 'aatype': template_aatype}
protein = fix_templates_aatype(protein) protein = fix_templates_aatype(protein)
template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2]) template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2])
assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours)) assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
...@@ -171,7 +171,10 @@ class TestDataTransforms(unittest.TestCase): ...@@ -171,7 +171,10 @@ class TestDataTransforms(unittest.TestCase):
with open('tests/test_data/features.pkl', 'rb') as file: with open('tests/test_data/features.pkl', 'rb') as file:
features = pickle.load(file) features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)} protein = {
'msa': torch.tensor(features['msa'], dtype=torch.int64),
'aatype': torch.tensor(features['aatype'], dtype=torch.int64),
}
protein = make_hhblits_profile(protein) protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42) protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
......
...@@ -50,18 +50,18 @@ class TestInputEmbedder(unittest.TestCase): ...@@ -50,18 +50,18 @@ class TestInputEmbedder(unittest.TestCase):
entity_id = asym_id entity_id = asym_id
sym_id = torch.zeros_like(entity_id) sym_id = torch.zeros_like(entity_id)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa}
if consts.is_multimer: if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m, ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx, max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative, use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain) max_relative_chain=max_relative_chain)
batch.update({"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}) batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
msa_emb, pair_emb = ie(batch)
else: else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k) ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
msa_emb, pair_emb = ie(batch)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m)) self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z)) self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......
...@@ -132,13 +132,31 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -132,13 +132,31 @@ class TestEvoformerStack(unittest.TestCase):
torch.as_tensor(masks["pair"]).cuda(), torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4, chunk_size=4,
_mask_trans=False, _mask_trans=False,
inplace_safe=False,
) )
out_repro_msa = out_repro_msa.cpu() out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu() out_repro_pair = out_repro_pair.cpu()
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
torch.as_tensor(activations["msa"]).cuda(),
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4,
_mask_trans=False,
inplace_safe=True,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
class TestExtraMSAStack(unittest.TestCase): class TestExtraMSAStack(unittest.TestCase):
...@@ -270,9 +288,6 @@ class TestMSATransition(unittest.TestCase): ...@@ -270,9 +288,6 @@ class TestMSATransition(unittest.TestCase):
.cpu() .cpu()
) )
print(out_gt)
print(out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......
...@@ -34,7 +34,7 @@ from openfold.utils.tensor_utils import ( ...@@ -34,7 +34,7 @@ from openfold.utils.tensor_utils import (
) )
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import random_affines_4x4 from tests.data_utils import random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed(): if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold() alphafold = compare_utils.import_alphafold()
...@@ -170,14 +170,21 @@ class TestFeats(unittest.TestCase): ...@@ -170,14 +170,21 @@ class TestFeats(unittest.TestCase):
out_gt = f.apply({}, None, **batch) out_gt = f.apply({}, None, **batch)
if consts.is_multimer: if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
to_tensor = (lambda t: torch.tensor(np.array(t)) to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array) if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())).view(*t.shape[:2], 12)) else torch.tensor(np.array(t.to_array())))
else: else:
to_tensor = lambda t: torch.tensor(np.array(t)) to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()} out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def rigid3x4_to_4x4(rigid3arr):
four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
four_by_four[..., :3, :4] = rigid3arr
four_by_four[..., 3, 3] = 1
return four_by_four
def flat12_to_4x4(flat12): def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3) rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:] trans = flat12[..., 9:]
...@@ -189,10 +196,12 @@ class TestFeats(unittest.TestCase): ...@@ -189,10 +196,12 @@ class TestFeats(unittest.TestCase):
return four_by_four return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4( convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4
out_gt["rigidgroups_gt_frames"] = convert_func(
out_gt["rigidgroups_gt_frames"] out_gt["rigidgroups_gt_frames"]
) )
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4( out_gt["rigidgroups_alt_gt_frames"] = convert_func(
out_gt["rigidgroups_alt_gt_frames"] out_gt["rigidgroups_alt_gt_frames"]
) )
...@@ -278,13 +287,21 @@ class TestFeats(unittest.TestCase): ...@@ -278,13 +287,21 @@ class TestFeats(unittest.TestCase):
) )
# Convert the Rigids to 4x4 transformation tensors # Convert the Rigids to 4x4 transformation tensors
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)) out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
trans_gt = list( out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
) if consts.is_multimer:
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1) rots_gt = torch.as_tensor(np.array(out_gt_rot))
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3) trans_gt = torch.as_tensor(np.array(out_gt_trans))
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1) else:
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1) transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4)) bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1 bottom_row[..., 3] = 1
...@@ -321,9 +338,6 @@ class TestFeats(unittest.TestCase): ...@@ -321,9 +338,6 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions), torch.tensor(restype_atom14_rigid_group_positions),
) )
if consts.is_multimer:
xyz = xyz.to_tensor()
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3)) self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import torch import torch
import numpy as np import numpy as np
from pathlib import Path
import unittest import unittest
import ml_collections as mlc import ml_collections as mlc
...@@ -301,7 +302,8 @@ class TestLoss(unittest.TestCase): ...@@ -301,7 +302,8 @@ class TestLoss(unittest.TestCase):
def test_find_structural_violations_compare(self): def test_find_structural_violations_compare(self):
def run_fsv(batch, pos, config): def run_fsv(batch, pos, config):
cwd = os.getcwd() cwd = os.getcwd()
os.chdir("tests/test_data") fpath = Path(__file__).parent.resolve() / "test_data"
os.chdir(str(fpath))
if consts.is_multimer: if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos) atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
...@@ -436,7 +438,7 @@ class TestLoss(unittest.TestCase): ...@@ -436,7 +438,7 @@ class TestLoss(unittest.TestCase):
"true_msa": np.random.randint(0, 21, (n_res, n_seq)), "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype( "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32 np.float32
), )
} }
out_gt = f.apply({}, None, value, batch)["loss"] out_gt = f.apply({}, None, value, batch)["loss"]
...@@ -448,7 +450,9 @@ class TestLoss(unittest.TestCase): ...@@ -448,7 +450,9 @@ class TestLoss(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
out_repro = masked_msa_loss( out_repro = masked_msa_loss(
value["logits"], value["logits"],
**batch, batch["true_msa"],
batch["bert_mask"],
consts.msa_logits
) )
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro) out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
...@@ -903,6 +907,9 @@ class TestLoss(unittest.TestCase): ...@@ -903,6 +907,9 @@ class TestLoss(unittest.TestCase):
), ),
} }
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
def _build_extra_feats_np(): def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b) b = data_transforms.make_atom14_masks(b)
...@@ -943,7 +950,7 @@ class TestLoss(unittest.TestCase): ...@@ -943,7 +950,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.") @unittest.skipIf(consts.is_multimer or "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self): def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error c_tm = config.model.heads.predicted_aligned_error
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pathlib import Path
import pickle import pickle
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -54,7 +55,7 @@ class TestModel(unittest.TestCase): ...@@ -54,7 +55,7 @@ class TestModel(unittest.TestCase):
n_res = consts.n_res n_res = consts.n_res
n_extra_seq = consts.n_extra n_extra_seq = consts.n_extra
c = model_config(consts.model) c = model_config(consts.model, train=True)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test # deepspeed for this test
...@@ -68,6 +69,7 @@ class TestModel(unittest.TestCase): ...@@ -68,6 +69,7 @@ class TestModel(unittest.TestCase):
).float() ).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1) batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res) batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)) batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res) t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()}) batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
...@@ -95,11 +97,14 @@ class TestModel(unittest.TestCase): ...@@ -95,11 +97,14 @@ class TestModel(unittest.TestCase):
out = model(batch) out = model(batch)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(consts.is_multimer, "Additional changes required for multimer.")
def test_compare(self): def test_compare(self):
#TODO: Fix test data for multimer MSA features
def run_alphafold(batch): def run_alphafold(batch):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
model = self.am_modules.AlphaFold(config.model) model = self.am_modules.AlphaFold(config.model)
return model( return model(
batch=batch, batch=batch,
is_training=False, is_training=False,
...@@ -110,7 +115,8 @@ class TestModel(unittest.TestCase): ...@@ -110,7 +115,8 @@ class TestModel(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights("") params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp: fpath = Path(__file__).parent.resolve() / "test_data/sample_feats.pickle"
with open(str(fpath), "rb") as fp:
batch = pickle.load(fp) batch = pickle.load(fp)
out_gt = f.apply(params, jax.random.PRNGKey(42), batch) out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
...@@ -150,6 +156,4 @@ class TestModel(unittest.TestCase): ...@@ -150,6 +156,4 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0) out_repro = out_repro.squeeze(0)
print(torch.mean(torch.abs(out_gt - out_repro)))
print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
) )
).cpu() ).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps)) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
class TestMSAColumnAttention(unittest.TestCase): class TestMSAColumnAttention(unittest.TestCase):
...@@ -158,9 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -158,9 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
) )
).cpu() ).cpu()
print(torch.mean(torch.abs(out_gt - out_repro))) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnGlobalAttention(unittest.TestCase): class TestMSAColumnGlobalAttention(unittest.TestCase):
......
...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets # Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps. # a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4)) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,54 +15,33 @@ ...@@ -15,54 +15,33 @@
import torch import torch
import unittest import unittest
from openfold.model.primitives import ( from openfold.model.primitives import Attention
Attention
)
from tests.config import consts from tests.config import consts
class TestLMA(unittest.TestCase): class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self): def test_lma_vs_attention(self):
batch_size = consts.batch_size batch_size = consts.batch_size
c_hidden = 32 c_hidden = 32
n = 2**12 n = 2 ** 12
no_heads = 4 no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda() q = torch.rand(batch_size, n, c_hidden).cuda()
k = torch.rand(batch_size, n, c_hidden).cuda() kv = torch.rand(batch_size, n, c_hidden).cuda()
v = torch.rand(batch_size, n, c_hidden).cuda()
bias = [torch.rand(no_heads, 1, n)] bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias] bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
lma = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a = Attention( a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda() ).cuda()
with torch.no_grad():
for n, p in lma.named_parameters():
attrs = n.split('.')
param = a
for attr in attrs:
param = getattr(param, attr)
param.copy_(p)
for m in [lma, a]:
m.linear_g.weight.copy_(gating_fill)
m.linear_o.weight.copy_(o_fill)
with torch.no_grad(): with torch.no_grad():
l = lma(q, k, v, biases=bias, use_lma=True, q_chunk_size=1024, kv_chunk_size=4096) l = a(q, kv, biases=bias, use_lma=True)
real = a(q, k, v, biases=bias) real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps) self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -99,7 +99,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -99,7 +99,7 @@ class TestStructureModule(unittest.TestCase):
z = torch.rand((batch_size, n, n, c_z)) z = torch.rand((batch_size, n, n, c_z))
f = torch.randint(low=0, high=21, size=(batch_size, n)).long() f = torch.randint(low=0, high=21, size=(batch_size, n)).long()
out = sm(s, z, f) out = sm({"single": s, "pair": z}, f)
if consts.is_multimer: if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4)) self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
...@@ -183,10 +183,13 @@ class TestStructureModule(unittest.TestCase): ...@@ -183,10 +183,13 @@ class TestStructureModule(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = model.structure_module( out_repro = model.structure_module(
torch.as_tensor(representations["single"]).cuda(), {
torch.as_tensor(representations["pair"]).cuda(), "single": torch.as_tensor(representations["single"]).cuda(),
"pair": torch.as_tensor(representations["pair"]).cuda(),
},
torch.as_tensor(batch["aatype"]).cuda(), torch.as_tensor(batch["aatype"]).cuda(),
mask=torch.as_tensor(batch["seq_mask"]).cuda(), mask=torch.as_tensor(batch["seq_mask"]).cuda(),
inplace_safe=False,
) )
out_repro = out_repro["positions"][-1].cpu() out_repro = out_repro["positions"][-1].cpu()
...@@ -286,7 +289,7 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -286,7 +289,7 @@ class TestInvariantPointAttention(unittest.TestCase):
if consts.is_multimer: if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines) rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4( transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float() torch.as_tensor(affines).float().cuda()
) )
sample_affine = rigids sample_affine = rigids
else: else:
......
...@@ -206,7 +206,7 @@ class Template(unittest.TestCase): ...@@ -206,7 +206,7 @@ class Template(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_compare(self): def test_compare(self):
def test_template_embedding(pair, batch, mask_2d): def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
te = self.am_modules.TemplateEmbedding( te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template, config.model.embeddings_and_evoformer.template,
...@@ -214,7 +214,7 @@ class Template(unittest.TestCase): ...@@ -214,7 +214,7 @@ class Template(unittest.TestCase):
) )
if consts.is_multimer: if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=multichain_mask_2d, is_training=False) act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
else: else:
act = te(pair, batch, mask_2d, is_training=False) act = te(pair, batch, mask_2d, is_training=False)
return act return act
...@@ -228,12 +228,12 @@ class Template(unittest.TestCase): ...@@ -228,12 +228,12 @@ class Template(unittest.TestCase):
batch = random_template_feats(n_templ, n_res) batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"] batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
multichain_mask_2d = None
if consts.is_multimer: if consts.is_multimer:
asym_id = batch['asym_id'][0] asym_id = batch['asym_id'][0]
multichain_mask_2d = ( multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :] asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32) ).astype(np.float32)
batch["multichain_mask_2d"] = multichain_mask_2d
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)] # Fetch pretrained parameters (but only from one block)]
...@@ -242,7 +242,7 @@ class Template(unittest.TestCase): ...@@ -242,7 +242,7 @@ class Template(unittest.TestCase):
) )
out_gt = f.apply( out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
).block_until_ready() ).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -259,7 +259,9 @@ class Template(unittest.TestCase): ...@@ -259,7 +259,9 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
templ_dim=0, templ_dim=0,
chunk_size=consts.chunk_size, chunk_size=consts.chunk_size,
multichain_mask_2d=multichain_mask_2d, multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
use_lma=False,
inplace_safe=False
) )
else: else:
out_repro = model.template_embedder( out_repro = model.template_embedder(
...@@ -267,7 +269,9 @@ class Template(unittest.TestCase): ...@@ -267,7 +269,9 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
templ_dim=0, templ_dim=0,
chunk_size=consts.chunk_size chunk_size=consts.chunk_size,
use_lma=False,
inplace_safe=False
) )
out_repro = out_repro["template_pair_embedding"] out_repro = out_repro["template_pair_embedding"]
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import torch import torch
import numpy as np import numpy as np
...@@ -89,13 +90,19 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -89,13 +90,19 @@ class TestTriangularAttention(unittest.TestCase):
if starting if starting
else model.evoformer.blocks[0].pair_stack.tri_att_end else model.evoformer.blocks[0].pair_stack.tri_att_end
) )
# To save memory, the full model transposes inputs outside of the
# triangle attention module. We adjust the module here.
module = copy.deepcopy(module)
module.starting = starting
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
chunk_size=None, chunk_size=None,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self): def test_tri_att_end_compare(self):
......
...@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_shape(self): def test_shape(self):
c_z = consts.c_z c_z = consts.c_z
c = 11 c = 11
outgoing = True
tm = TriangleMultiplicationOutgoing( tm = TriangleMultiplicationOutgoing(
c_z, c_z,
c, c,
outgoing,
) )
n_res = consts.c_z n_res = consts.c_z
...@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=4,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self): def test_tri_mul_out_compare(self):
...@@ -106,6 +105,39 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -106,6 +105,39 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_tri_mul_in_compare(self): def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True) self._tri_mul_compare(incoming=True)
def _tri_mul_inplace(self, incoming=False):
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=False,
).cpu()
# This has to come second because inference mode is in-place
out_inplace = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=2,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
def test_tri_mul_out_inference(self):
self._tri_mul_inplace()
def test_tri_mul_in_inference(self):
self._tri_mul_inplace(incoming=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import ( ...@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot, quat_to_rot,
rot_to_quat, rot_to_quat,
) )
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
......
import argparse
import os
import logging
import random
import numpy
import torch
from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
relax_protein
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from scripts.utils import add_data_args
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

# Parse "<major>.<minor>...." out of the torch version string.
_torch_version_fields = torch.__version__.split(".")
torch_major_version = int(_torch_version_fields[0])
torch_minor_version = int(_torch_version_fields[1])

# TF32 matmuls give a large speedup on Ampere-class GPUs; the knob only
# exists from torch 1.12 onward, so gate on the parsed version.
if (torch_major_version, torch_minor_version) >= (1, 12):
    torch.set_float32_matmul_precision("high")

# Inference-only script: disable autograd bookkeeping globally.
torch.set_grad_enabled(False)
def main(args):
    """Thread a single query sequence onto a template structure and predict
    its 3D structure with OpenFold.

    Reads exactly one sequence from ``args.input_fasta``, builds model
    features using chain ``args.chain_id`` of the mmCIF at
    ``args.input_mmcif`` as the sole template, then, for every model
    yielded by ``load_models_from_command_line``, writes an unrelaxed PDB
    and runs Amber relaxation in that model's output directory.

    Raises:
        ValueError: if the FASTA file contains more than one sequence.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    # Seed numpy/torch so feature processing is reproducible when a seed
    # is supplied (otherwise draw one at random).
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)
    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")
    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    output_name = f'{query_tag}_{args.config_preset}'
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimension --- we don't need it anymore.
        # BUGFIX: bind the numpy copy to a NEW name instead of clobbering
        # processed_feature_dict; the original dict of torch tensors must
        # survive intact for any subsequent model from the generator.
        np_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            np_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            200,  # ri_multimer_gap; there are no multimer sequences here, so the value is irrelevant
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )
        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")
parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")
parser.add_argument(
"--chain_id", type=str,
help="""The chain ID of the chain in the template to use"""
)
parser.add_argument(
"--model_device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser.add_argument(
"--config_preset", type=str, default="model_1",
help="""Name of a model config preset defined in openfold/config.py"""
)
parser.add_argument(
"--jax_param_path", type=str, default=None,
help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser.add_argument(
"--openfold_checkpoint_path", type=str, default=None,
help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
)
parser.add_argument(
"--subtract_plddt", action="store_true", default=False,
help=""""Whether to output (100 - pLDDT) in the B-factor column instead
of the pLDDT itself"""
)
parser.add_argument(
"--data_random_seed", type=str, default=None
)
add_data_args(parser)
args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment