Commit 39a6d0e6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Merging in main branch

parents d8ee9c5f 84659c93
import argparse
import ctypes
from datetime import date
import sys
def add_data_args(parser: argparse.ArgumentParser):
......@@ -43,7 +45,7 @@ def add_data_args(parser: argparse.ArgumentParser):
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument(
'--max_template_date', type=str,
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
......@@ -52,3 +54,67 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--release_dates_path', type=str, default=None
)
def get_nvidia_cc():
    """
    Returns a tuple containing the Compute Capability of the first GPU
    installed in the system (formatted as a tuple of ints, e.g. (8, 0))
    and an error message. When the former is provided, the latter is None,
    and vice versa.

    Adapted from a script by Jan Schlüter at
    https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """
    CUDA_SUCCESS = 0
    libnames = [
        'libcuda.so',
        'libcuda.dylib',
        'cuda.dll',
        '/usr/local/cuda/compat/libcuda.so',  # For Docker
    ]
    # Try each candidate driver library in turn; the for-else runs only when
    # none of them could be loaded.
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        else:
            break
    else:
        return None, "Could not load any of: " + ' '.join(libnames)

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    device = ctypes.c_int()
    error_str = ctypes.c_char_p()

    result = cuda.cuInit(0)
    if result != CUDA_SUCCESS:
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        if error_str.value:
            return None, error_str.value.decode()
        else:
            return None, "Unknown error: cuInit returned %d" % result
    result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
    if result != CUDA_SUCCESS:
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        return None, error_str.value.decode()
    if nGpus.value < 1:
        return None, "No GPUs detected"
    result = cuda.cuDeviceGet(ctypes.byref(device), 0)
    if result != CUDA_SUCCESS:
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        return None, error_str.value.decode()
    if cuda.cuDeviceComputeCapability(
        ctypes.byref(cc_major), ctypes.byref(cc_minor), device
    ) != CUDA_SUCCESS:
        return None, "Compute Capability not found"

    return (cc_major.value, cc_minor.value), None
......@@ -13,6 +13,7 @@ import glob
import math
import os
from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
......@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """Read the global training step recorded in a DeepSpeed ZeRO checkpoint.

    The checkpoint directory is expected to contain a ``latest`` file whose
    contents name the most recent checkpoint tag, e.g. ``global_step1234``.

    Args:
        checkpoint_dir: Path to the DeepSpeed checkpoint directory.

    Returns:
        The global step parsed from the tag, as an int.

    Raises:
        ValueError: If the ``latest`` file is missing, or if its tag does not
            match the expected ``global_step<N>`` format.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    match = re.match(r"global_step([0-9]+)", tag)
    if match is None:
        # Previously a malformed tag crashed with AttributeError on
        # match.group(); raise a descriptive error instead.
        raise ValueError(
            f"Unrecognized checkpoint tag {tag!r} in {latest_path}"
        )
    return int(match.group(1))
if __name__ == "__main__":
......
......@@ -16,7 +16,9 @@ import os
from setuptools import setup, Extension, find_packages
import subprocess
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from scripts.utils import get_nvidia_cc
version_dependent_macros = [
......@@ -26,48 +28,56 @@ version_dependent_macros = [
]
# nvcc flags shared by all CUDA extensions built below. (The diff view showed
# several entries twice due to whitespace-only changes; each flag belongs in
# the list exactly once.)
extra_cuda_flags = [
    '-std=c++14',
    '-maxrregcount=50',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '--expt-relaxed-constexpr',
    '--expt-extended-lambda',
]
def get_cuda_bare_metal_version(cuda_dir):
    """Query the CUDA toolkit version via ``nvcc -V``.

    Args:
        cuda_dir: Root of the CUDA installation (e.g. CUDA_HOME), or None when
            no CUDA toolkit is present.

    Returns:
        A tuple ``(raw_output, major, minor)`` where ``raw_output`` is the full
        ``nvcc -V`` text and ``major``/``minor`` are version components parsed
        from it (as strings). Returns ``(None, -1, 0)`` when ``cuda_dir`` is
        None, signalling a CPU-only build to the caller.
    """
    if cuda_dir is None:
        print("CUDA is not found, cpu version is installed")
        return None, -1, 0
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    # nvcc -V prints "... release <major>.<minor>, ..."; the version token
    # immediately follows the literal word "release".
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor
compute_capabilities = set([
(3, 7), # K80, e.g.
(5, 2), # Titan X
(6, 1), # GeForce 1000-series
])
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
compute_capabilities.add((7, 0))
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
compute_capabilities.add((8, 0))
compute_capability, _ = get_nvidia_cc()
if compute_capability is not None:
compute_capabilities = set([compute_capability])
cc_flag = []
for major, minor in list(compute_capabilities):
cc_flag.extend([
'-gencode',
f'arch=compute_{major}{minor},code=sm_{major}{minor}',
])
extra_cuda_flags += cc_flag
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
setup(
name='openfold',
version='0.1.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=[CUDAExtension(
if bare_metal_major != -1:
modules = [CUDAExtension(
name="attn_core_inplace_cuda",
sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp",
......@@ -75,34 +85,51 @@ setup(
],
include_dirs=[
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
os.path.dirname(os.path.abspath(__file__)),
'openfold/utils/kernel/csrc/'
)
],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros,
'nvcc': (
['-O3', '--use_fast_math'] +
version_dependent_macros +
['-O3', '--use_fast_math'] +
version_dependent_macros +
extra_cuda_flags
),
}
)],
)]
else:
modules = [CppExtension(
name="attn_core_inplace_cuda",
sources=[
"openfold/utils/kernel/csrc/softmax_cuda.cpp",
"openfold/utils/kernel/csrc/softmax_cuda_stub.cpp",
],
extra_compile_args={
'cxx': ['-O3'],
}
)]
setup(
name='openfold',
version='1.0.1',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]),
include_package_data=True,
package_data={
"openfold": ['utils/kernel/csrc/*'],
"": ["resources/stereo_chemical_props.txt"]
},
ext_modules=modules,
cmdclass={'build_ext': BuildExtension},
install_requires=[
'torch',
'deepspeed',
'biopython',
'ml-collections',
'numpy',
'scipy',
'pytorch_lightning',
'dm-tree',
],
classifiers=[
'License :: OSI Approved :: Apache Software License',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.7,'
'Programming Language :: Python :: 3.7,'
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
)
......@@ -10,6 +10,7 @@ consts = mlc.ConfigDict(
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
......
......@@ -30,7 +30,10 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces = []
asym_ids = []
for idx in range(n_chain - 1):
piece = randint(min_chain_len, (n_res - sum(pieces) - n_chain + idx - min_chain_len))
n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
if n_stop <= min_chain_len:
break
piece = randint(min_chain_len, n_stop)
pieces.append(piece)
asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
......
......@@ -45,7 +45,7 @@ class TestDataTransforms(unittest.TestCase):
template_seq_one_hot = torch.FloatTensor(template_seq.shape[0], 20).zero_()
template_seq_one_hot.scatter_(1, template_seq, 1)
template_aatype = template_seq_one_hot.clone().detach().unsqueeze(0)
protein = {'template_aatype': template_aatype}
protein = {'template_aatype': template_aatype, 'aatype': template_aatype}
protein = fix_templates_aatype(protein)
template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2])
assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
......@@ -171,7 +171,10 @@ class TestDataTransforms(unittest.TestCase):
with open('tests/test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = {
'msa': torch.tensor(features['msa'], dtype=torch.int64),
'aatype': torch.tensor(features['aatype'], dtype=torch.int64),
}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
......
......@@ -50,18 +50,18 @@ class TestInputEmbedder(unittest.TestCase):
entity_id = asym_id
sym_id = torch.zeros_like(entity_id)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa}
if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain)
batch.update({"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id})
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
msa_emb, pair_emb = ie(batch)
else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
msa_emb, pair_emb = ie(batch)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......
......@@ -132,13 +132,31 @@ class TestEvoformerStack(unittest.TestCase):
torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4,
_mask_trans=False,
inplace_safe=False,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
torch.as_tensor(activations["msa"]).cuda(),
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(),
chunk_size=4,
_mask_trans=False,
inplace_safe=True,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -270,9 +288,6 @@ class TestMSATransition(unittest.TestCase):
.cpu()
)
print(out_gt)
print(out_repro)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......
......@@ -34,7 +34,7 @@ from openfold.utils.tensor_utils import (
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
from tests.data_utils import random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -170,14 +170,21 @@ class TestFeats(unittest.TestCase):
out_gt = f.apply({}, None, **batch)
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())).view(*t.shape[:2], 12))
else torch.tensor(np.array(t.to_array())))
else:
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def rigid3x4_to_4x4(rigid3arr):
four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
four_by_four[..., :3, :4] = rigid3arr
four_by_four[..., 3, 3] = 1
return four_by_four
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]
......@@ -189,10 +196,12 @@ class TestFeats(unittest.TestCase):
return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4
out_gt["rigidgroups_gt_frames"] = convert_func(
out_gt["rigidgroups_gt_frames"]
)
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_alt_gt_frames"] = convert_func(
out_gt["rigidgroups_alt_gt_frames"]
)
......@@ -278,13 +287,21 @@ class TestFeats(unittest.TestCase):
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()
if consts.is_multimer:
rots_gt = torch.as_tensor(np.array(out_gt_rot))
trans_gt = torch.as_tensor(np.array(out_gt_trans))
else:
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1
......@@ -321,9 +338,6 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions),
)
if consts.is_multimer:
xyz = xyz.to_tensor()
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed()
......
......@@ -15,6 +15,7 @@
import os
import torch
import numpy as np
from pathlib import Path
import unittest
import ml_collections as mlc
......@@ -301,7 +302,8 @@ class TestLoss(unittest.TestCase):
def test_find_structural_violations_compare(self):
def run_fsv(batch, pos, config):
cwd = os.getcwd()
os.chdir("tests/test_data")
fpath = Path(__file__).parent.resolve() / "test_data"
os.chdir(str(fpath))
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
......@@ -436,7 +438,7 @@ class TestLoss(unittest.TestCase):
"true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32
),
)
}
out_gt = f.apply({}, None, value, batch)["loss"]
......@@ -448,7 +450,9 @@ class TestLoss(unittest.TestCase):
with torch.no_grad():
out_repro = masked_msa_loss(
value["logits"],
**batch,
batch["true_msa"],
batch["bert_mask"],
consts.msa_logits
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
......@@ -903,6 +907,9 @@ class TestLoss(unittest.TestCase):
),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b)
......@@ -943,7 +950,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
@unittest.skipIf(consts.is_multimer or "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pickle
import torch
import torch.nn as nn
......@@ -54,7 +55,7 @@ class TestModel(unittest.TestCase):
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config(consts.model)
c = model_config(consts.model, train=True)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
......@@ -68,6 +69,7 @@ class TestModel(unittest.TestCase):
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
......@@ -95,11 +97,14 @@ class TestModel(unittest.TestCase):
out = model(batch)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(consts.is_multimer, "Additional changes required for multimer.")
def test_compare(self):
#TODO: Fix test data for multimer MSA features
def run_alphafold(batch):
config = compare_utils.get_alphafold_config()
model = self.am_modules.AlphaFold(config.model)
return model(
batch=batch,
is_training=False,
......@@ -110,7 +115,8 @@ class TestModel(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
fpath = Path(__file__).parent.resolve() / "test_data/sample_feats.pickle"
with open(str(fpath), "rb") as fp:
batch = pickle.load(fp)
out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
......@@ -150,6 +156,4 @@ class TestModel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)
print(torch.mean(torch.abs(out_gt - out_repro)))
print(torch.max(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
......@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
class TestMSAColumnAttention(unittest.TestCase):
......@@ -158,9 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).cpu()
print(torch.mean(torch.abs(out_gt - out_repro)))
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase):
......
......@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
if __name__ == "__main__":
......
......@@ -15,54 +15,33 @@
import torch
import unittest
from openfold.model.primitives import (
Attention
)
from openfold.model.primitives import Attention
from tests.config import consts
class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self):
batch_size = consts.batch_size
c_hidden = 32
n = 2**12
c_hidden = 32
n = 2 ** 12
no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda()
k = torch.rand(batch_size, n, c_hidden).cuda()
v = torch.rand(batch_size, n, c_hidden).cuda()
kv = torch.rand(batch_size, n, c_hidden).cuda()
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
lma = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
for n, p in lma.named_parameters():
attrs = n.split('.')
param = a
for attr in attrs:
param = getattr(param, attr)
param.copy_(p)
for m in [lma, a]:
m.linear_g.weight.copy_(gating_fill)
m.linear_o.weight.copy_(o_fill)
with torch.no_grad():
l = lma(q, k, v, biases=bias, use_lma=True, q_chunk_size=1024, kv_chunk_size=4096)
real = a(q, k, v, biases=bias)
l = a(q, kv, biases=bias, use_lma=True)
real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -99,7 +99,7 @@ class TestStructureModule(unittest.TestCase):
z = torch.rand((batch_size, n, n, c_z))
f = torch.randint(low=0, high=21, size=(batch_size, n)).long()
out = sm(s, z, f)
out = sm({"single": s, "pair": z}, f)
if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
......@@ -183,10 +183,13 @@ class TestStructureModule(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.structure_module(
torch.as_tensor(representations["single"]).cuda(),
torch.as_tensor(representations["pair"]).cuda(),
{
"single": torch.as_tensor(representations["single"]).cuda(),
"pair": torch.as_tensor(representations["pair"]).cuda(),
},
torch.as_tensor(batch["aatype"]).cuda(),
mask=torch.as_tensor(batch["seq_mask"]).cuda(),
inplace_safe=False,
)
out_repro = out_repro["positions"][-1].cpu()
......@@ -286,7 +289,7 @@ class TestInvariantPointAttention(unittest.TestCase):
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
torch.as_tensor(affines).float().cuda()
)
sample_affine = rigids
else:
......
......@@ -206,7 +206,7 @@ class Template(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def test_template_embedding(pair, batch, mask_2d):
def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
config = compare_utils.get_alphafold_config()
te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
......@@ -214,7 +214,7 @@ class Template(unittest.TestCase):
)
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=multichain_mask_2d, is_training=False)
act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False)
return act
......@@ -228,12 +228,12 @@ class Template(unittest.TestCase):
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
multichain_mask_2d = None
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
batch["multichain_mask_2d"] = multichain_mask_2d
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
......@@ -242,7 +242,7 @@ class Template(unittest.TestCase):
)
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -259,7 +259,9 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=multichain_mask_2d,
multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
use_lma=False,
inplace_safe=False
)
else:
out_repro = model.template_embedder(
......@@ -267,7 +269,9 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size
chunk_size=consts.chunk_size,
use_lma=False,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"]
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
import numpy as np
......@@ -89,13 +90,19 @@ class TestTriangularAttention(unittest.TestCase):
if starting
else model.evoformer.blocks[0].pair_stack.tri_att_end
)
# To save memory, the full model transposes inputs outside of the
# triangle attention module. We adjust the module here.
module = copy.deepcopy(module)
module.starting = starting
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
chunk_size=None,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
......@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_shape(self):
c_z = consts.c_z
c = 11
outgoing = True
tm = TriangleMultiplicationOutgoing(
c_z,
c,
outgoing,
)
n_res = consts.c_z
......@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
......@@ -106,6 +105,39 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True)
    def _tri_mul_inplace(self, incoming=False):
        """Check that the chunked in-place inference path of triangle
        multiplication matches the stock (out-of-place) forward pass.

        Args:
            incoming: If True, exercise the "incoming" variant
                (tri_mul_in); otherwise the "outgoing" one (tri_mul_out).
        """
        n_res = consts.n_res

        # Random pair activations and a random binary pair mask.
        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
        pair_mask = pair_mask.astype(np.float32)

        model = compare_utils.get_global_pretrained_openfold()
        module = (
            model.evoformer.blocks[0].pair_stack.tri_mul_in
            if incoming
            else model.evoformer.blocks[0].pair_stack.tri_mul_out
        )

        # Reference result from the ordinary (out-of-place) forward pass.
        out_stock = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            inplace_safe=False,
        ).cpu()

        # This has to come second because inference mode is in-place
        out_inplace = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            inplace_safe=True, _inplace_chunk_size=2,
        ).cpu()

        self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
    def test_tri_mul_out_inference(self):
        # In-place vs. stock inference for the outgoing variant.
        self._tri_mul_inplace()
    def test_tri_mul_in_inference(self):
        # In-place vs. stock inference for the incoming variant.
        self._tri_mul_inplace(incoming=True)
if __name__ == "__main__":
unittest.main()
......@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
......
import argparse
import os
import logging
import random
import numpy
import torch
from openfold.config import model_config
from openfold.data import feature_pipeline
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
from openfold.np import protein
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
relax_protein
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from scripts.utils import add_data_args
# Module-level logger for this script.
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

# Parse the installed torch version so features can be gated on it below.
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
    torch_major_version > 1 or
    (torch_major_version == 1 and torch_minor_version >= 12)
):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Inference-only script: gradients are never needed.
torch.set_grad_enabled(False)
def main(args):
    """Thread a single query sequence onto a user-supplied template structure
    and run OpenFold prediction + relaxation on the result.

    Args:
        args: Parsed command-line arguments (see the argparse setup under
            ``__main__``).
    """
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    # Seed numpy/torch for reproducibility; draw a fresh seed when none is
    # supplied on the command line.
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2**32)
    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")
    query_sequence = sequences[0]
    query_tag = tags[0]

    # Build sequence features aligned against the user-provided template mmCIF.
    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path)

    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )

    # Move all features to the requested device as tensors.
    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    output_name = f'{query_tag}_{args.config_preset}'
    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore
        # NOTE(review): this rebinds processed_feature_dict to numpy arrays
        # inside the loop; if the generator ever yields more than one model,
        # later iterations would receive the converted dict — confirm
        # single-model usage.
        processed_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            processed_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            200, # this is the ri_multimer_gap. There's no multimer sequences here, so it doesnt matter what its set to
            args.subtract_plddt
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )

        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Positional inputs: the query sequence and the template structure.
    parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
    parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")
    parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")
    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )
    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    # Exactly one source of weights is used; see the fallback logic below.
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help=""""Whether to output (100 - pLDDT) in the B-factor column instead
                of the pLDDT itself"""
    )
    parser.add_argument(
        "--data_random_seed", type=str, default=None
    )
    # Shared data-pipeline arguments (alignment binaries, template dates, ...).
    add_data_args(parser)
    args = parser.parse_args()

    # Default to the bundled JAX parameters when no checkpoint was given.
    if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if(args.model_device == "cpu" and torch.cuda.is_available()):
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment