"lib/llm/src/request_template.rs" did not exist on "c9130f8f8ce264379131e9ee2973534fe4cbf713"
Commit 9f6b67f3 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix another FP16 overflow, finish tests, update util scripts

parent 893fe372
@@ -123,7 +123,7 @@ config = mlc.ConfigDict({
             "dropout_rate": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,#1e9,
         },
         "template_pointwise_attention": {
             "c_t": c_t,
@@ -133,11 +133,11 @@ config = mlc.ConfigDict({
             "c_hidden": 16,
             "no_heads": 4,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,#1e9,
         },
-        "inf": 1e9,
+        "inf": 1e5,#1e9,
         "eps": eps,#1e-6,
-        "enabled": False,#True,
+        "enabled": True,
         "embed_angles": True,
     },
     "extra_msa": {
@@ -160,7 +160,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,#1e9,
             "eps": eps,#1e-10,
         },
         "enabled": True,
@@ -181,7 +181,7 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
+            "inf": 1e5,#1e9,
             "eps": eps,#1e-10,
         },
         "structure_module": {
......
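The recurring `inf` change above (1e9 → 1e5) is the FP16 overflow fix named in the commit message: these constants feed large negative attention-mask biases, and half precision saturates just above 65504, so 1e9-scale values turn into `inf` wherever they leak into fp16 arithmetic and can surface as NaNs after a softmax. A 1e5-scale bias is still far larger than any logit the model produces, so masking behaves identically. A quick demonstration of both points (a standalone sketch, not code from this repo):

```python
import torch

# Half precision tops out just above 6.5e4, so 1e9 is unrepresentable
print(torch.finfo(torch.float16).max)          # 65504.0
print(torch.tensor(1e9, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)

# For masking, the bias only has to dwarf the logits: exp(-1e5) already
# underflows to exactly 0 in the softmax, just like exp(-1e9)
logits = torch.randn(4)
bias = torch.tensor([0.0, -1e5, 0.0, -1e5])
print(torch.softmax(logits + bias, dim=-1))    # masked entries are exactly 0
```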
@@ -192,7 +192,6 @@ class TemplatePairStackBlock(nn.Module):
         return z

-
 class TemplatePairStack(nn.Module):
     """
     Implements Algorithm 16.
@@ -273,7 +272,7 @@ class TemplatePairStack(nn.Module):
                     _mask_trans=_mask_trans,
                 ) for b in self.blocks
             ],
-            args=(t),
+            args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
......
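The `args=(t)` → `args=(t,)` change above is a real bug fix, not a style tweak: without the trailing comma the parentheses are just grouping, so the checkpointing helper (which presumably splats `*args`) received the bare tensor instead of a one-element tuple. For illustration:

```python
t = "anything"
print((t) is t)   # True  -- (t) is simply t; the parentheses do nothing
print((t,))       # ('anything',) -- the trailing comma builds the 1-tuple
```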
@@ -115,8 +115,8 @@ def compute_residx(batch):
     restype_atom37_to_atom14 = aatype.new_tensor(
         restype_atom37_to_atom14
     )
-    restype_atom14_mask = aatype.new_tensor(
-        restype_atom14_mask, dtype=float_type
+    restype_atom14_mask = batch["seq_mask"].new_tensor(
+        restype_atom14_mask
     )

     residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
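The `new_tensor` change matters because `Tensor.new_tensor` inherits dtype and device from the tensor it is called on. `aatype` is an integer tensor, so the float mask built from it needed an explicit `dtype`; deriving it from the floating-point `seq_mask` instead picks up whatever float width the model is running in (fp16 or fp32) automatically. A sketch of the semantics:

```python
import torch

aatype = torch.zeros(4, dtype=torch.long)       # residue indices: integer
seq_mask = torch.ones(4, dtype=torch.float16)   # float mask, e.g. under fp16

# new_tensor copies dtype/device from the tensor it is called on
print(aatype.new_tensor([1.0, 0.0]).dtype)      # torch.int64 -- would truncate
print(seq_mask.new_tensor([1.0, 0.0]).dtype)    # torch.float16 -- matches model
```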
@@ -527,13 +527,17 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
     )

     n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]

+    # TODO: Consider running this in double precision
     affines = T.make_transform_from_reference(
         n_xyz=batch["template_all_atom_positions"][..., n, :],
         ca_xyz=batch["template_all_atom_positions"][..., ca, :],
         c_xyz=batch["template_all_atom_positions"][..., c, :],
         eps=eps,
     )

     points = affines.get_trans()[..., None, :, :]
     affine_vec = affines[..., None].invert_apply(points)
+    inv_distance_scalar = torch.rsqrt(
+        eps + torch.sum(affine_vec ** 2, dim=-1)
+    )
......
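The `eps` inside the `torch.rsqrt` above is the standard guard for degenerate geometry: a zero squared distance (coincident template atoms, or padding) would otherwise give `rsqrt(0) = inf` and NaN gradients through the backward pass. A minimal demonstration:

```python
import torch

d2 = torch.tensor([0.0, 1.0, 4.0])   # squared distances, one degenerate
print(torch.rsqrt(d2))                # tensor([inf, 1.0000, 0.5000])
print(torch.rsqrt(1e-20 + d2))        # tensor([1.0000e+10, 1.0000, 0.5000])
```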
@@ -407,7 +407,15 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         },
     }

-    if(version not in ["model_1", "model_2"]):
+    no_templ = [
+        "model_3",
+        "model_4",
+        "model_5",
+        "model_3_ptm",
+        "model_4_ptm",
+        "model_5_ptm",
+    ]
+    if(version in no_templ):
         evo_dict = translations["evoformer"]
         keys = list(evo_dict.keys())
         for k in keys:
......
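The rewritten check fixes a version-string pitfall: `model_1_ptm` is not in `["model_1", "model_2"]`, so the old negative test lumped the pTM variants of models 1 and 2 in with the templateless presets and stripped their template weights. Enumerating the templateless presets avoids that:

```python
version = "model_1_ptm"   # uses templates, like model_1

# Old logic: anything that isn't exactly model_1/model_2 was treated as
# templateless -- wrongly including the ptm variants
print(version not in ["model_1", "model_2"])   # True (wrong)

# New logic: only presets known to lack template weights qualify
no_templ = ["model_3", "model_4", "model_5",
            "model_3_ptm", "model_4_ptm", "model_5_ptm"]
print(version in no_templ)                     # False (correct)
```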
@@ -1428,10 +1428,10 @@ class AlphaFoldLoss(nn.Module):
         for k,loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
-                print(k)
+                #print(k)
                 loss = loss_fn()
-                print(weight * loss)
+                #print(weight * loss)
                 cum_loss = cum_loss + weight * loss

-        print(cum_loss)
+        #print(cum_loss)

         return cum_loss
@@ -31,6 +31,10 @@ pushd lib/conda/envs/$ENV_NAME/lib/python3.9/site-packages/ \
 wget -q -P openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt

+# Certain tests need access to this file
+mkdir -p tests/test_data/alphafold/common
+ln -s openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
+
 # Download pretrained openfold weights
 scripts/download_alphafold_params.sh openfold/resources
......
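One subtlety with the new `ln -s` line: a relative symlink target is resolved against the directory containing the link, not the shell's working directory at creation time, so the link only resolves if `openfold/resources/stereo_chemical_props.txt` is reachable from inside `tests/test_data/alphafold/common/`. The semantics, sketched in Python with throwaway paths:

```python
import os, tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "deep", "dir"))
open(os.path.join(root, "data.txt"), "w").close()

# Relative target: resolved against deep/dir/, where no data.txt exists
os.symlink("data.txt", os.path.join(root, "deep", "dir", "broken.txt"))
print(os.path.exists(os.path.join(root, "deep", "dir", "broken.txt")))  # False

# Climbing back up relative to the link's own directory works
os.symlink("../../data.txt", os.path.join(root, "deep", "dir", "ok.txt"))
print(os.path.exists(os.path.join(root, "deep", "dir", "ok.txt")))      # True
```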
 #!/bin/bash

+FLAGS=""
+while getopts ":v" option; do
+   case $option in
+      v)
+         FLAGS=$(echo "-v $FLAGS" | xargs) # strip whitespace
+         ;;
+      *)
+         echo "Invalid option: ${option}"
+         ;;
+   esac
+done
+
-python3 -m unittest "$@" || \
+python3 -m unittest $FLAGS "$@" || \
     echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
@@ -64,7 +64,7 @@ def get_global_pretrained_openfold():
             """Cannot load pretrained parameters. Make sure to run the
                installation script before running tests."""
         )
-    import_jax_weights_(_model, _param_path)
+    import_jax_weights_(_model, _param_path, version="model_1_ptm")
     _model = _model.cuda()
     return _model
......
@@ -15,7 +15,12 @@

 import torch
 import numpy as np
 import unittest

-from alphafold.model.embedders import *
+from openfold.model.embedders import (
+    InputEmbedder,
+    RecyclingEmbedder,
+    TemplateAngleEmbedder,
+    TemplatePairEmbedder,
+)

 class TestInputEmbedder(unittest.TestCase):
......
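Several test files in this commit swap star imports for explicit ones (here and in the structure-module and tensor-utils tests below). Besides documenting what each test actually uses, explicit imports fail loudly when a symbol moves, whereas `import *` can silently shadow names. A contrived but real example:

```python
from math import *     # math defines no __all__, so this also pulls in pow

print(pow(2.0, 3.0))   # 8.0 -- but this is now math.pow, floats only
# The three-argument built-in form is gone: pow(2, 3, 5) raises TypeError
```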
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_pseudo_beta_fn_compare(self):
def test_pbf(aatype, all_atom_pos, all_atom_mask):
return alphafold.model.modules.pseudo_beta_fn(
aatype,
all_atom_pos,
all_atom_mask,
)
f = hk.transform(test_pbf)
n_res = consts.n_res
aatype = np.random.randint(0, 22, (n_res,))
all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
all_atom_mask = np.random.randint(0, 2, (n_res, 37))
out_gt_pos, out_gt_mask = f.apply(
{}, None, aatype, all_atom_pos, all_atom_mask
)
out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))
out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
torch.tensor(aatype).cuda(),
torch.tensor(all_atom_pos).cuda(),
torch.tensor(all_atom_mask).cuda(),
)
out_repro_pos = out_repro_pos.cpu()
out_repro_mask = out_repro_mask.cpu()
self.assertTrue(
torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
)
self.assertTrue(
torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_torsion_angles_compare(self):
def run_test(aatype, all_atom_pos, all_atom_mask):
return alphafold.model.all_atom.atom37_to_torsion_angles(
aatype,
all_atom_pos,
all_atom_mask,
placeholder_for_undefined=False,
)
f = hk.transform(run_test)
n_templ = 7
n_res = 13
aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
all_atom_mask = np.random.randint(
0, 2, (n_templ, n_res, 37)
).astype(np.float32)
out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)
out_repro = feats.atom37_to_torsion_angles(
torch.as_tensor(aatype).cuda(),
torch.as_tensor(all_atom_pos).cuda(),
torch.as_tensor(all_atom_mask).cuda(),
)
tasc = out_repro["torsion_angles_sin_cos"].cpu()
atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
tam = out_repro["torsion_angles_mask"].cpu()
# This function is extremely sensitive to floating point imprecisions,
# so it is given much greater latitude in comparison tests.
self.assertTrue(
torch.mean(
torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
) < 0.01
)
self.assertTrue(
torch.mean(
torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
) < 0.01
)
self.assertTrue(torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
return alphafold.model.all_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask
)
f = hk.transform(run_atom37_to_frames)
n_res = consts.n_res
batch = {
"aatype": np.random.randint(0, 21, (n_res,)),
"all_atom_positions":
np.random.rand(n_res, 37, 3).astype(np.float32),
"all_atom_mask":
np.random.randint(0, 2, (n_res, 37)).astype(np.float32),
}
out_gt = f.apply({}, None, **batch)
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k:to_tensor(v) for k,v in out_gt.items()}
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]
four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
four_by_four[..., :3, :3] = rot
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_gt_frames"]
)
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_alt_gt_frames"]
)
to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k,v in out_gt.items():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
def test_torsion_angles_to_frames_shape(self):
batch_size = 2
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
ts = T(rots, trans)
angles = torch.rand((batch_size, n, 7, 2))
aas = torch.tensor([i % 2 for i in range(n)])
aas = torch.stack([aas for _ in range(batch_size)])
frames = feats.torsion_angles_to_frames(
ts,
angles,
aas,
torch.tensor(restype_rigid_group_default_frame),
)
self.assertTrue(frames.shape == (batch_size, n, 8))
@compare_utils.skip_unless_alphafold_installed()
def test_torsion_angles_to_frames_compare(self):
def run_torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos,
)
f = hk.transform(run_torsion_angles_to_frames)
n_res = consts.n_res
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
out_gt = f.apply(
{}, None, aatype, rigids, torsion_angles_sin_cos
)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out = feats.torsion_angles_to_frames(
transformations.cuda(),
torch.as_tensor(torsion_angles_sin_cos).cuda(),
torch.as_tensor(aatype).cuda(),
torch.tensor(restype_rigid_group_default_frame).cuda(),
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot)
)
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1
transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)
transforms_repro = out.to_4x4().cpu()
        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )
def test_frames_and_literature_positions_to_atom14_pos_shape(self):
batch_size = consts.batch_size
n_res = consts.n_res
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
ts = T(rots, trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
xyz = feats.frames_and_literature_positions_to_atom14_pos(
ts,
f,
torch.tensor(restype_rigid_group_default_frame),
torch.tensor(restype_atom14_to_rigid_group),
torch.tensor(restype_atom14_mask),
torch.tensor(restype_atom14_rigid_group_positions),
)
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed()
def test_frames_and_literature_positions_to_atom14_pos_compare(self):
def run_f(aatype, affines):
am = alphafold.model
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines
)
f = hk.transform(run_f)
n_res = consts.n_res
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = T.from_4x4(torch.as_tensor(affines).float())
out_gt = f.apply(
{}, None, aatype, rigids
)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
out_repro = feats.frames_and_literature_positions_to_atom14_pos(
transformations.cuda(),
torch.as_tensor(aatype).cuda(),
torch.tensor(restype_rigid_group_default_frame).cuda(),
torch.tensor(restype_atom14_to_rigid_group).cuda(),
torch.tensor(restype_atom14_mask).cuda(),
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
if __name__ == "__main__":
unittest.main()
@@ -17,19 +17,17 @@ import numpy as np
 import unittest

 from config import model_config
-from alphafold.model.model import AlphaFold
-from alphafold.model.import_weights import *
+from openfold.model.model import AlphaFold
+from openfold.utils.import_weights import import_jax_weights_

 class TestImportWeights(unittest.TestCase):
     def test_import_jax_weights_(self):
-        npz_path = "tests/model/alphafold_2/params_model_1.npz"
-        c = model_config("model_1").model
-        c.evoformer_stack.blocks_per_ckpt = None # don't want to set up
-                                                 # deepspeed for this test
-        model = AlphaFold(c)
+        npz_path = "openfold/resources/params/params_model_1_ptm.npz"
+        c = model_config("model_1_ptm")
+        c.globals.blocks_per_ckpt = None
+        model = AlphaFold(c.model)

         import_jax_weights_(
             model, npz_path,
......
This diff is collapsed.
@@ -92,4 +92,3 @@ class TestOuterProductMean(unittest.TestCase):

-
 if __name__ == "__main__":
     unittest.main()
@@ -21,13 +21,17 @@ from openfold.np.residue_constants import (
     restype_atom14_to_rigid_group,
     restype_atom14_mask,
     restype_atom14_rigid_group_positions,
+    restype_atom37_mask,
 )
-from openfold.model.structure_module import *
+from openfold.model.structure_module import (
+    _torsion_angles_to_frames,
+    _frames_and_literature_positions_to_atom14_pos,
+    StructureModule,
+    StructureModuleTransition,
+    BackboneUpdate,
+    AngleResnet,
+    InvariantPointAttention,
+)
 from openfold.utils.affine_utils import T
 import openfold.utils.feats as feats
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import (
@@ -42,10 +46,10 @@ if(compare_utils.alphafold_is_installed()):

 class TestStructureModule(unittest.TestCase):
     def test_structure_module_shape(self):
-        batch_size = 2
-        n = 5
-        c_s = 7
-        c_z = 11
+        batch_size = consts.batch_size
+        n = consts.n_res
+        c_s = consts.c_s
+        c_z = consts.c_z
         c_ipa = 13
         c_resnet = 17
         no_heads_ipa = 6
@@ -94,47 +98,6 @@ class TestStructureModule(unittest.TestCase):
             out["positions"].shape == (no_layers, batch_size, n, 14, 3)
         )

-    def test_torsion_angles_to_frames_shape(self):
-        batch_size = 2
-        n = 5
-
-        rots = torch.rand((batch_size, n, 3, 3))
-        trans = torch.rand((batch_size, n, 3))
-        ts = T(rots, trans)
-
-        angles = torch.rand((batch_size, n, 7, 2))
-
-        aas = torch.tensor([i % 2 for i in range(n)])
-        aas = torch.stack([aas for _ in range(batch_size)])
-
-        frames = _torsion_angles_to_frames(
-            ts,
-            angles,
-            aas,
-            torch.tensor(restype_rigid_group_default_frame),
-        )
-
-        self.assertTrue(frames.shape == (batch_size, n, 8))
-
-    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
-        batch_size = 2
-        n = 5
-
-        rots = torch.rand((batch_size, n, 8, 3, 3))
-        trans = torch.rand((batch_size, n, 8, 3))
-        ts = T(rots, trans)
-
-        f = torch.randint(low=0, high=21, size=(batch_size, n)).long()
-
-        xyz = _frames_and_literature_positions_to_atom14_pos(
-            ts,
-            f,
-            torch.tensor(restype_rigid_group_default_frame),
-            torch.tensor(restype_atom14_to_rigid_group),
-            torch.tensor(restype_atom14_mask),
-            torch.tensor(restype_atom14_rigid_group_positions),
-        )
-
-        self.assertTrue(xyz.shape == (batch_size, n, 14, 3))
-
     def test_structure_module_transition_shape(self):
         batch_size = 2
         n = 5
@@ -152,6 +115,76 @@ class TestStructureModule(unittest.TestCase):
         self.assertTrue(shape_before == shape_after)

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_structure_module_compare(self):
+        config = compare_utils.get_alphafold_config()
+        c_sm = config.model.heads.structure_module
+        c_global = config.model.global_config
+
+        def run_sm(representations, batch):
+            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
+            representations = {
+                k:jax.lax.stop_gradient(v) for k,v in representations.items()
+            }
+            batch = {
+                k:jax.lax.stop_gradient(v) for k,v in batch.items()
+            }
+            return sm(representations, batch, is_training=False)
+
+        f = hk.transform(run_sm)
+
+        n_res = 200
+
+        representations = {
+            'single': np.random.rand(n_res, consts.c_s).astype(np.float32),
+            'pair':
+                np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
+        }
+
+        batch = {
+            'seq_mask': np.random.randint(0, 2, (n_res,)),
+            'aatype': np.random.randint(0, 21, (n_res,)),
+        }
+
+        batch['atom14_atom_exists'] = np.take(
+            restype_atom14_mask,
+            batch['aatype'],
+            axis=0
+        )
+
+        batch['atom37_atom_exists'] = np.take(
+            restype_atom37_mask,
+            batch['aatype'],
+            axis=0
+        )
+
+        batch.update(feats.compute_residx_np(batch))
+
+        params = compare_utils.fetch_alphafold_module_weights(
+            "alphafold/alphafold_iteration/structure_module"
+        )
+
+        key = jax.random.PRNGKey(42)
+        out_gt = f.apply(
+            params, key, representations, batch
+        )
+        out_gt = torch.as_tensor(
+            np.array(out_gt["final_atom14_positions"].block_until_ready())
+        )
+
+        model = compare_utils.get_global_pretrained_openfold()
+        out_repro = model.structure_module(
+            torch.as_tensor(representations["single"]).cuda(),
+            torch.as_tensor(representations["pair"]).cuda(),
+            torch.as_tensor(batch["aatype"]).cuda(),
+            mask=torch.as_tensor(batch["seq_mask"]).cuda(),
+        )
+        out_repro = out_repro["positions"][-1].cpu()
+
+        # The structure module, thanks to angle normalization, is very volatile
+        # We only assess the mean here. Heuristically speaking, it seems to
+        # have lower error in general on real rather than synthetic data.
+        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.01)
+

 class TestBackboneUpdate(unittest.TestCase):
     def test_shape(self):
......
@@ -137,7 +137,8 @@ class TestTemplatePairStack(unittest.TestCase):
             _mask_trans=False,
         ).cpu()

-        self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


 class Template(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
......
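This hunk also standardizes the tolerance checks on the `torch.max(torch.abs(diff)) < eps` form. That form is worth the consistency because the parenthesization is easy to get wrong: if the comparison slips inside `torch.max`, the max runs over a boolean tensor, and the assertion passes when any element is within tolerance instead of all of them. Compare:

```python
import torch

a = torch.tensor([0.0, 100.0])   # second entry is wildly off
b = torch.zeros(2)
eps = 1e-4

# Comparison inside max: reduces a bool tensor -> True if ANY entry is close
print(torch.max(torch.abs(a - b) < eps))   # tensor(True) -- false positive
# Comparison outside: the intended "max error under eps" check
print(torch.max(torch.abs(a - b)) < eps)   # tensor(False)
```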
@@ -100,9 +100,11 @@ class TestTriangularAttention(unittest.TestCase):
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))

     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
         self._tri_att_compare()

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_tri_att_start_compare(self):
+        self._tri_att_compare(starting=True)
......
@@ -16,8 +16,8 @@ import math
 import torch
 import unittest

-from openfold.utils.affine_utils import *
-from openfold.utils.tensor_utils import *
+from openfold.utils.affine_utils import T, quat_to_rot
+from openfold.utils.tensor_utils import chunk_layer

 X_90_ROT = torch.tensor([
......