Commit 92f1932e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Initial commit

parent 3d9c2de3
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.evoformer import *
class TestEvoformerStack(unittest.TestCase):
    """Shape checks for the full Evoformer stack."""

    def test_shape(self):
        # Small, pairwise-distinct dimensions so any transposition bug
        # surfaces as a shape mismatch.
        n_batch, n_seq, n_res = 5, 27, 29
        c_m, c_z, c_s = 7, 11, 23
        stack = EvoformerStack(
            c_m,
            c_z,
            12,    # c_hidden_msa_att
            17,    # c_hidden_opm
            19,    # c_hidden_mul
            14,    # c_hidden_pair_att
            c_s,
            3,     # no_heads_msa
            7,     # no_heads_pair
            2,     # no_blocks
            2,     # transition_n
            0.15,  # msa_dropout
            0.25,  # pair_stack_dropout
            blocks_per_ckpt=None,  # no checkpointing: avoids deepspeed in tests
            chunk_size=4,
            inf=1e9,
            eps=1e-10,
        ).eval()
        m = torch.rand((n_batch, n_seq, n_res, c_m))
        z = torch.rand((n_batch, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(n_batch, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(n_batch, n_res, n_res))
        m_out, z_out, s_out = stack(m, z, msa_mask, pair_mask)
        # MSA and pair activations keep their shapes; the single
        # representation collapses the sequence dimension.
        self.assertEqual(m_out.shape, m.shape)
        self.assertEqual(z_out.shape, z.shape)
        self.assertEqual(s_out.shape, (n_batch, n_res, c_s))
class TestExtraMSAStack(unittest.TestCase):
    """Shape check for the extra-MSA stack."""

    def test_shape(self):
        n_batch, n_seq, n_res = 2, 23, 5
        c_m, c_z = 7, 11
        stack = ExtraMSAStack(
            c_m,
            c_z,
            12,    # c_hidden_msa_att
            17,    # c_hidden_opm
            19,    # c_hidden_mul
            16,    # c_hidden_tri_att
            3,     # no_heads_msa
            8,     # no_heads_pair
            2,     # no_blocks
            5,     # transition_n
            0.15,  # msa_dropout
            0.25,  # pair_stack_dropout
            blocks_per_ckpt=None,  # no checkpointing: avoids deepspeed in tests
            chunk_size=4,
            inf=1e9,
            eps=1e-10,
        ).eval()
        m = torch.rand((n_batch, n_seq, n_res, c_m))
        z = torch.rand((n_batch, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(n_batch, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(n_batch, n_res, n_res))
        z_out = stack(m, z, msa_mask, pair_mask)
        # Only the pair representation is returned; its shape is preserved.
        self.assertEqual(z_out.shape, z.shape)
class TestMSATransition(unittest.TestCase):
    """Shape check for the MSA transition layer."""

    def test_shape(self):
        c_m, hidden_factor = 7, 11
        mt = MSATransition(c_m, hidden_factor, chunk_size=4)
        m_in = torch.rand((2, 3, 5, c_m))
        m_out = mt(m_in)
        # The transition is pointwise over channels: shape unchanged.
        self.assertEqual(m_out.shape, m_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from config import model_config
from alphafold.model.model import AlphaFold
from alphafold.model.import_weights import *
class TestImportWeights(unittest.TestCase):
    """Spot-checks that imported JAX parameters land in the right tensors."""

    def test_import_jax_weights_(self):
        npz_path = "tests/model/alphafold_2/params_model_1.npz"
        config = model_config("model_1").model
        # Checkpointing would require deepspeed, which tests avoid.
        config.evoformer_stack.blocks_per_ckpt = None
        model = AlphaFold(config)
        import_jax_weights_(
            model, npz_path,
        )
        data = np.load(npz_path)
        prefix = "alphafold/alphafold_iteration/"
        # Pairs of (source array mapped to torch layout, target parameter).
        test_pairs = []
        # Ordinary linear weight: JAX stores [in, out]; torch wants [out, in].
        test_pairs.append((
            torch.as_tensor(
                data[prefix + "structure_module/initial_projection//weights"]
            ).transpose(-1, -2),
            model.structure_module.linear_in.weight,
        ))
        # Ordinary layer-norm parameter: copied verbatim.
        test_pairs.append((
            torch.as_tensor(
                data[prefix + "evoformer/prev_pair_norm//offset"],
            ),
            model.recycling_embedder.layer_norm_z.bias,
        ))
        # Weight sliced out of a stacked (per-block) parameter.
        stacked_key = prefix + (
            "evoformer/evoformer_iteration/outer_product_mean/"
            "left_projection//weights"
        )
        test_pairs.append((
            torch.as_tensor(data[stacked_key][1].transpose(-1, -2)),
            model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
        ))
        for w_alpha, w_repro in test_pairs:
            self.assertTrue(torch.all(w_alpha == w_repro))
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import numpy as np
import unittest
from alphafold.utils.loss import *
from alphafold.utils.utils import T
class TestLoss(unittest.TestCase):
    """Smoke tests: each loss function runs on random inputs without error."""

    def test_run_torsion_angle_loss(self):
        n_batch, n_res = 2, 5
        a = torch.rand((n_batch, n_res, 7, 2))
        a_gt = torch.rand((n_batch, n_res, 7, 2))
        a_alt_gt = torch.rand((n_batch, n_res, 7, 2))
        torsion_angle_loss(a, a_gt, a_alt_gt)

    def test_run_fape(self):
        n_batch, n_frames, n_atoms = 2, 7, 5
        x = torch.rand((n_batch, n_atoms, 3))
        x_gt = torch.rand((n_batch, n_atoms, 3))
        frames = T(
            torch.rand((n_batch, n_frames, 3, 3)),
            torch.rand((n_batch, n_frames, 3)),
        )
        frames_gt = T(
            torch.rand((n_batch, n_frames, 3, 3)),
            torch.rand((n_batch, n_frames, 3)),
        )
        compute_fape(frames, x, frames_gt, x_gt)

    def test_between_residue_bond_loss(self):
        n_batch, n_res = 2, 10
        pred_pos = torch.rand(n_batch, n_res, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (n_batch, n_res, 14))
        residue_index = torch.arange(n_res).unsqueeze(0)
        aatype = torch.randint(0, 22, (n_batch, n_res,))
        between_residue_bond_loss(
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )

    def test_between_residue_clash_loss(self):
        n_batch, n_res = 2, 10
        pred_pos = torch.rand(n_batch, n_res, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (n_batch, n_res, 14))
        atom14_atom_radius = torch.rand(n_batch, n_res, 14)
        residue_index = torch.arange(n_res).unsqueeze(0)
        between_residue_clash_loss(
            pred_pos,
            pred_atom_mask,
            atom14_atom_radius,
            residue_index,
        )

    def test_find_structural_violations(self):
        n_res = 10
        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n_res, 14)),
            "residue_index": torch.arange(n_res),
            "aatype": torch.randint(0, 21, (n_res,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n_res, 14)).long(),
        }
        pred_pos = torch.rand(n_res, 14, 3)
        # NOTE(review): ml_collections is only in scope here via the star
        # import of the loss module — confirm it is re-exported there.
        config = ml_collections.ConfigDict({
            "clash_overlap_tolerance": 1.5,
            "violation_tolerance_factor": 12.0,
        })
        find_structural_violations(batch, pred_pos, config)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from config import *
from alphafold.model.model import *
from alphafold.utils.utils import my_tree_map
from tests.alphafold.utils.utils import (
random_template_feats,
random_extra_msa_feats,
)
class TestModel(unittest.TestCase):
    """End-to-end dry run of the full AlphaFold model on random features."""

    def test_dry_run(self):
        n_batch, n_seq, n_templ, n_res, n_extra_seq = 2, 5, 7, 11, 13
        c = model_config("model_1").model
        c.no_cycles = 2
        c.evoformer_stack.no_blocks = 4  # no need to go overboard here
        c.evoformer_stack.blocks_per_ckpt = None  # avoids deepspeed setup
        model = AlphaFold(c)

        tf_dim = c.input_embedder.tf_dim
        tf = torch.randint(
            tf_dim - 1, size=(n_batch, n_res)
        )
        batch = {"target_feat": nn.functional.one_hot(tf, tf_dim).float()}
        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
        batch["residue_index"] = torch.arange(n_res)
        batch["msa_feat"] = torch.rand(
            (n_batch, n_seq, n_res, c.input_embedder.msa_dim)
        )
        # Random template and extra-MSA features, converted to torch tensors.
        t_feats = random_template_feats(n_templ, n_res, batch_size=n_batch)
        for k, v in t_feats.items():
            batch[k] = torch.tensor(v)
        extra_feats = random_extra_msa_feats(
            n_extra_seq, n_res, batch_size=n_batch
        )
        for k, v in extra_feats.items():
            batch[k] = torch.tensor(v)
        batch["msa_mask"] = torch.randint(
            low=0, high=2, size=(n_batch, n_seq, n_res)
        )
        batch["seq_mask"] = torch.randint(
            low=0, high=2, size=(n_batch, n_res)
        )

        def add_recycling_dims(t):
            # Every feature gets a trailing dimension of size no_cycles.
            return t.unsqueeze(-1).expand(*t.shape, c.no_cycles)

        batch = my_tree_map(add_recycling_dims, batch, torch.Tensor)
        with torch.no_grad():
            model(batch)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.msa import *
class TestMSARowAttentionWithPairBias(unittest.TestCase):
    """Shape check for MSA row attention with pair bias."""

    def test_shape(self):
        c_m, c_z, c_hidden, no_heads = 7, 11, 52, 4
        attn = MSARowAttentionWithPairBias(c_m, c_z, c_hidden, no_heads)
        m_in = torch.rand((2, 3, 5, c_m))
        z = torch.rand((2, 5, 5, c_z))
        m_out = attn(m_in, z)
        # Attention over residues leaves the MSA shape unchanged.
        self.assertEqual(m_out.shape, m_in.shape)
class TestMSAColumnAttention(unittest.TestCase):
    """Shape check for MSA column attention."""

    def test_shape(self):
        c_m, c_hidden, no_heads = 7, 44, 4
        attn = MSAColumnAttention(c_m, c_hidden, no_heads)
        x_in = torch.rand((2, 3, 5, c_m))
        x_out = attn(x_in)
        # Attention over sequences leaves the MSA shape unchanged.
        self.assertEqual(x_out.shape, x_in.shape)
class TestMSAColumnGlobalAttention(unittest.TestCase):
    """Shape check for global (mean-query) MSA column attention."""

    def test_shape(self):
        c_m, c_hidden, no_heads = 7, 44, 4
        attn = MSAColumnGlobalAttention(c_m, c_hidden, no_heads)
        x_in = torch.rand((2, 3, 5, c_m))
        x_out = attn(x_in)
        # The global variant also preserves the MSA shape.
        self.assertEqual(x_out.shape, x_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.outer_product_mean import *
class TestOuterProductMean(unittest.TestCase):
    """Shape check for the outer-product-mean pair update."""

    def test_shape(self):
        n_batch, n_seq, n_res = 2, 5, 7
        c_m, c_hidden, c_z = 11, 13, 17
        opm = OuterProductMean(c_m, c_z, c_hidden)
        m = torch.rand((n_batch, n_seq, n_res, c_m))
        mask = torch.randint(0, 2, size=(n_batch, n_seq, n_res))
        out = opm(m, mask)
        # MSA activations collapse into a [*, n_res, n_res, c_z] pair update.
        self.assertEqual(out.shape, (n_batch, n_res, n_res, c_z))
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.pair_transition import *
class TestPairTransition(unittest.TestCase):
    """Shape check for the pair transition layer."""

    def test_shape(self):
        c_z, hidden_factor = 5, 4
        pt = PairTransition(c_z, hidden_factor)
        n_batch, n_res = 4, 256
        z_in = torch.rand((n_batch, n_res, n_res, c_z))
        mask = torch.randint(0, 2, size=(n_batch, n_res, n_res))
        z_out = pt(z_in, mask)
        # The transition is pointwise over channels: shape unchanged.
        self.assertEqual(z_out.shape, z_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from alphafold.model.structure_module import *
from alphafold.model.structure_module import (
_torsion_angles_to_frames,
_frames_and_literature_positions_to_atom14_pos,
)
from alphafold.utils.utils import T
class TestStructureModule(unittest.TestCase):
    """Shape checks for the structure module and its sub-layers."""

    def test_structure_module_shape(self):
        n_batch, n_res = 2, 5
        c_s, c_z = 7, 11
        no_layers = 3
        no_angles = 7
        sm = StructureModule(
            c_s,
            c_z,
            13,    # c_ipa
            17,    # c_resnet
            6,     # no_heads_ipa
            4,     # no_query_points
            4,     # no_value_points
            0.1,   # dropout_rate
            no_layers,
            3,     # no_transition_layers
            3,     # no_resnet_layers
            no_angles,
            10,    # trans_scale_factor
            1e-6,  # ar_epsilon
            1e5,   # inf
        )
        s = torch.rand((n_batch, n_res, c_s))
        z = torch.rand((n_batch, n_res, n_res, c_z))
        aatype = torch.randint(low=0, high=21, size=(n_batch, n_res)).long()
        out = sm(s, z, aatype)
        # One output per refinement layer is stacked along a leading dim.
        self.assertEqual(
            out["transformations"].shape, (no_layers, n_batch, n_res, 4, 4)
        )
        self.assertEqual(
            out["angles"].shape, (no_layers, n_batch, n_res, no_angles, 2)
        )
        self.assertEqual(
            out["positions"].shape, (no_layers, n_batch, n_res, 14, 3)
        )

    def test_torsion_angles_to_frames_shape(self):
        n_batch, n_res = 2, 5
        ts = T(
            torch.rand((n_batch, n_res, 3, 3)),
            torch.rand((n_batch, n_res, 3)),
        )
        angles = torch.rand((n_batch, n_res, 7, 2))
        # Alternating amino-acid indices, repeated over the batch.
        aas = torch.stack(
            [torch.tensor([i % 2 for i in range(n_res)])] * n_batch
        )
        frames = _torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )
        # One frame per residue for each of the 8 rigid groups.
        self.assertEqual(frames.shape, (n_batch, n_res, 8))

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        n_batch, n_res = 2, 5
        ts = T(
            torch.rand((n_batch, n_res, 8, 3, 3)),
            torch.rand((n_batch, n_res, 8, 3)),
        )
        aatype = torch.randint(low=0, high=21, size=(n_batch, n_res)).long()
        xyz = _frames_and_literature_positions_to_atom14_pos(
            ts,
            aatype,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )
        # atom14 representation: 14 heavy-atom slots, 3 coordinates each.
        self.assertEqual(xyz.shape, (n_batch, n_res, 14, 3))

    def test_structure_module_transition_shape(self):
        n_batch, n_res, c = 2, 5, 7
        smt = StructureModuleTransition(c, 3, 0.1)  # 3 layers, dropout 0.1
        s_in = torch.rand((n_batch, n_res, c))
        s_out = smt(s_in)
        self.assertEqual(s_out.shape, s_in.shape)
class TestBackboneUpdate(unittest.TestCase):
    """Shape check for the backbone-update projection."""

    def test_shape(self):
        n_batch, n_res, c_in = 2, 3, 5
        bu = BackboneUpdate(c_in)
        t = bu(torch.rand((n_batch, n_res, c_in)))
        # The update produces one rigid transform per residue.
        self.assertEqual(t.rots.shape, (n_batch, n_res, 3, 3))
        self.assertEqual(t.trans.shape, (n_batch, n_res, 3))
class TestInvariantPointAttention(unittest.TestCase):
    """Shape check for invariant point attention (IPA)."""

    def test_shape(self):
        c_m, c_z, c_hidden = 13, 17, 19
        no_heads, no_qp, no_vp = 5, 7, 11
        n_batch, n_res = 2, 23
        s_in = torch.rand((n_batch, n_res, c_m))
        z = torch.rand((n_batch, n_res, n_res, c_z))
        mask = torch.ones((n_batch, n_res))
        frames = T(
            torch.rand((n_batch, n_res, 3, 3)),
            torch.rand((n_batch, n_res, 3)),
        )
        ipa = InvariantPointAttention(
            c_m, c_z, c_hidden, no_heads, no_qp, no_vp
        )
        s_out = ipa(s_in, z, frames, mask)
        # IPA updates the single representation in place shape-wise.
        self.assertEqual(s_out.shape, s_in.shape)
class TestAngleResnet(unittest.TestCase):
    """Shape check for the side-chain angle resnet."""

    def test_shape(self):
        n_batch, n_res = 2, 3
        c_s, c_hidden = 13, 11
        no_layers, no_angles = 5, 7
        ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, 1e-12)
        a = torch.rand((n_batch, n_res, c_s))
        a_initial = torch.rand((n_batch, n_res, c_s))
        out = ar(a, a_initial)
        # One (sin, cos) pair per predicted angle.
        self.assertEqual(out.shape, (n_batch, n_res, no_angles, 2))
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.template import *
class TestTemplatePointwiseAttention(unittest.TestCase):
    """Shape check for template pointwise attention."""

    def test_shape(self):
        n_batch, n_templ, n_res = 2, 3, 17
        c_t, c_z, c_hidden, no_heads = 5, 7, 26, 13
        tpa = TemplatePointwiseAttention(
            c_t, c_z, c_hidden, no_heads, chunk_size=4
        )
        t = torch.rand((n_batch, n_templ, n_res, n_res, c_t))
        z = torch.rand((n_batch, n_res, n_res, c_z))
        z_update = tpa(t, z)
        # The template embedding is folded into a pair-shaped update.
        self.assertEqual(z_update.shape, z.shape)
class TestTemplatePairStack(unittest.TestCase):
    """Shape check for the template pair stack."""

    def test_shape(self):
        n_batch, n_templ, n_res = 2, 3, 5
        c_t = 5
        stack = TemplatePairStack(
            c_t,
            c_hidden_tri_att=7,
            c_hidden_tri_mul=7,
            no_blocks=2,
            no_heads=4,
            pair_transition_n=15,
            dropout_rate=0.25,
            chunk_size=4,
        )
        t_in = torch.rand((n_batch, n_templ, n_res, n_res, c_t))
        mask = torch.randint(0, 2, (n_batch, n_templ, n_res, n_res))
        t_out = stack(t_in, mask)
        # The stack refines the template pair tensor without reshaping it.
        self.assertEqual(t_out.shape, t_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.triangular_attention import *
class TestTriangularAttention(unittest.TestCase):
    """Shape check for triangle attention (starting-node variant)."""

    def test_shape(self):
        c_z, c_hidden, no_heads = 2, 12, 4
        tan = TriangleAttention(
            c_z,
            c_hidden,
            no_heads,
            True,  # starting-node variant
        )
        n_batch, n_res = 4, 7
        x_in = torch.rand((n_batch, n_res, n_res, c_z))
        x_out = tan(x_in)
        # Attention over one edge axis leaves the pair shape unchanged.
        self.assertEqual(x_out.shape, x_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from alphafold.model.triangular_multiplicative_update import *
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
    """Shape check for the triangle multiplicative update."""

    def test_shape(self):
        c_z, c_hidden = 7, 11
        tm = TriangleMultiplicativeUpdate(
            c_z,
            c_hidden,
            True,  # "outgoing" edge variant
        )
        n_batch, n_res = 2, 5
        x_in = torch.rand((n_batch, n_res, n_res, c_z))
        mask = torch.randint(0, 2, size=(n_batch, n_res, n_res))
        x_out = tm(x_in, mask)
        # The multiplicative update preserves the pair tensor shape.
        self.assertEqual(x_out.shape, x_in.shape)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import unittest
from alphafold.utils.utils import *
# Rotation by +90 degrees about the x axis.
X_90_ROT = torch.tensor([
    [1, 0, 0],
    [0, 0,-1],
    [0, 1, 0],
])

# Rotation by -90 degrees about the x axis (inverse of X_90_ROT).
X_NEG_90_ROT = torch.tensor([
    [1, 0, 0],
    [0, 0, 1],
    [0,-1, 0],
])
class TestAffineT(unittest.TestCase):
    """Tests for the rigid-transformation helper T and related utilities."""

    def test_T_from_3_points_shape(self):
        """Frames built from 3 points have [*, 3, 3] rotations; origin at x2."""
        batch_size = 2
        n_res = 5
        x1 = torch.rand((batch_size, n_res, 3))
        x2 = torch.rand((batch_size, n_res, 3))
        x3 = torch.rand((batch_size, n_res, 3))
        t = T.from_3_points(x1, x2, x3)
        rot, tra = t.rots, t.trans
        self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
        # The translation is expected to be exactly the middle point x2.
        self.assertTrue(torch.all(tra == x2))

    def test_T_from_4x4(self):
        """from_4x4 splits a homogeneous 4x4 into rotation and translation."""
        batch_size = 2
        transf = [
            [1, 0, 0, 1],
            [0, 0,-1, 2],
            [0, 1, 0, 3],
            [0, 0, 0, 1],
        ]
        transf = torch.tensor(transf)
        # Expected components: upper-left 3x3 block and last column.
        true_rot = transf[:3, :3]
        true_trans = transf[:3, 3]
        transf = torch.stack(
            [transf for _ in range(batch_size)],
            dim=0
        )
        t = T.from_4x4(transf)
        rot, tra = t.rots, t.trans
        self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
        self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))

    def test_T_shape(self):
        """T.shape reports only the batch dims, not the 3x3 / 3 suffixes."""
        batch_size = 2
        n = 5
        transf = T(
            torch.rand((batch_size, n, 3, 3)),
            torch.rand((batch_size, n, 3))
        )
        self.assertTrue(transf.shape == (batch_size, n))

    def test_T_concat(self):
        """T.concat concatenates rots and trans along the requested batch dim."""
        batch_size = 2
        n = 5
        transf = T(
            torch.rand((batch_size, n, 3, 3)),
            torch.rand((batch_size, n, 3))
        )
        transf_concat = T.concat([transf, transf], dim=0)
        self.assertTrue(transf_concat.rots.shape == (batch_size * 2, n, 3, 3))
        transf_concat = T.concat([transf, transf], dim=1)
        self.assertTrue(transf_concat.rots.shape == (batch_size, n * 2, 3, 3))
        # The first half along dim 1 must be the original transformation.
        self.assertTrue(torch.all(transf_concat.rots[:, :n] == transf.rots))
        self.assertTrue(torch.all(transf_concat.trans[:, :n] == transf.trans))

    def test_T_compose(self):
        """Composing a +90 deg x-rotation with its inverse gives the identity."""
        trans_1 = [0, 1, 0]
        trans_2 = [0, 0, 1]
        t1 = T(X_90_ROT, torch.tensor(trans_1))
        t2 = T(X_NEG_90_ROT, torch.tensor(trans_2))
        t3 = t1.compose(t2)
        self.assertTrue(torch.all(t3.rots == torch.eye(3)))
        # t1's rotation maps trans_2 (0,0,1) to (0,-1,0), cancelling trans_1.
        self.assertTrue(torch.all(t3.trans == 0))

    def test_T_apply(self):
        """apply() rotates then translates points, broadcasting over the batch."""
        rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
        trans = torch.tensor([1, 1, 1])
        trans = torch.stack([trans, trans], dim=0)
        t = T(rots, trans)
        x = torch.arange(30)
        x = torch.stack([x, x], dim=0)
        x = x.view(2, -1, 3) # [2, 10, 3]
        pts = t[..., None].apply(x)
        # All simple consequences of the two x-axis rotations
        self.assertTrue(torch.all(pts[..., 0] == x[..., 0] + 1))
        self.assertTrue(torch.all(pts[0, :, 1] == x[0, :, 2] * -1 + 1))
        self.assertTrue(torch.all(pts[1, :, 1] == x[1, :, 2] + 1))
        self.assertTrue(torch.all(pts[0, :, 2] == x[0, :, 1] + 1))
        self.assertTrue(torch.all(pts[1, :, 2] == x[1, :, 1] * -1 + 1))

    def test_quat_to_rot(self):
        """A unit quaternion with half-angle pi/4 converts to X_90_ROT."""
        forty_five = math.pi / 4
        # Quaternion layout here is (w, x, y, z).
        quat = torch.tensor([math.cos(forty_five), math.sin(forty_five), 0, 0])
        rot = quat_to_rot(quat)
        eps = 1e-07
        self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))

    def test_chunk_layer_tensor(self):
        """chunk_layer over a plain tensor matches the unchunked call exactly."""
        x = torch.rand(2, 4, 5, 15)
        l = torch.nn.Linear(15, 30)
        chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
        unchunked = l(x)
        self.assertTrue(torch.all(chunked == unchunked))

    def test_chunk_layer_dict(self):
        """chunk_layer reassembles nested dict outputs chunk by chunk."""
        class LinearDictLayer(torch.nn.Linear):
            # Returns a nested dict so reassembly of both levels is exercised.
            def forward(self, input):
                out = super().forward(input)
                return {"out": out, "inner": {"out": out + 1}}
        x = torch.rand(2, 4, 5, 15)
        l = LinearDictLayer(15, 30)
        chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
        unchunked = l(x)
        self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
        self.assertTrue(
            torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
        )
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def random_template_feats(n_templ, n, batch_size=None):
    """Build a dict of random template features for testing.

    Args:
        n_templ: Number of templates.
        n: Number of residues.
        batch_size: Optional leading batch dimension; omitted when None.

    Returns:
        Dict of float32 numpy arrays (int64 for "template_aatype") with
        shapes [batch_size?, n_templ, n, ...].
    """
    # Idiomatic conditional instead of the mutate-a-list pattern.
    b = () if batch_size is None else (batch_size,)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_masks": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        # Scaled by 10 so coordinates span an Angstrom-like range.
        "template_all_atom_positions": np.random.rand(
            *b, n_templ, n, 37, 3
        ) * 10,
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    # aatype is an index tensor and must stay integral.
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch
def random_extra_msa_feats(n_extra, n, batch_size=None):
    """Build a dict of random extra-MSA features for testing.

    Args:
        n_extra: Number of extra MSA sequences.
        n: Number of residues.
        batch_size: Optional leading batch dimension; omitted when None.

    Returns:
        Dict of numpy arrays shaped [batch_size?, n_extra, n]; "extra_msa"
        is int64 (residue indices), everything else float32.
    """
    # Idiomatic conditional instead of the mutate-a-list pattern.
    b = () if batch_size is None else (batch_size,)
    batch = {
        # Residue-type indices in [0, 22).
        "extra_msa":
            np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
        "extra_has_deletion":
            np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
        "extra_deletion_value":
            np.random.rand(*b, n_extra, n).astype(np.float32),
        "extra_msa_mask":
            np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
    }
    return batch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment