Commit df6b97f2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

First draft of loss class

parent 15895ea9
...@@ -129,6 +129,8 @@ class AngleResnet(nn.Module): ...@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2] # [*, no_angles * 2]
s = self.linear_out(s) s = self.linear_out(s)
unnormalized_s = s
# [*, no_angles, 2] # [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2) s = s.view(*s.shape[:-1], -1, 2)
norm_denom = torch.sqrt( norm_denom = torch.sqrt(
...@@ -139,7 +141,7 @@ class AngleResnet(nn.Module): ...@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
) )
s = s / norm_denom s = s / norm_denom
return s return unnormalized_s, s
class InvariantPointAttention(nn.Module): class InvariantPointAttention(nn.Module):
...@@ -723,7 +725,7 @@ class StructureModule(nn.Module): ...@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
t = t.compose(self.bb_update(s)) t = t.compose(self.bb_update(s))
# [*, N, 7, 2] # [*, N, 7, 2]
a = self.angle_resnet(s, s_initial) unnormalized_a, a = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames( all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), a, f, t.scale_translation(self.trans_scale_factor), a, f,
...@@ -735,8 +737,10 @@ class StructureModule(nn.Module): ...@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
) )
preds = { preds = {
"transformations": "frames":
t.scale_translation(self.trans_scale_factor).to_4x4(), t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global,
"unnormalized_angles": unnormalized_a,
"angles": a, "angles": a,
"positions": pred_xyz, "positions": pred_xyz,
} }
......
This diff is collapsed.
...@@ -180,4 +180,54 @@ config = mlc.ConfigDict({ ...@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
"max_outer_iterations": 20, "max_outer_iterations": 20,
"exclude_residues": [], "exclude_residues": [],
}, },
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.,
},
"fape": {
"backbone": {
"clamp_distance": 10.,
"loss_unit_distance": 10.,
"weight": 0.5,
}
"sidechain": {
"clamp_distance": 10.,
"length_scale": 10.,
"weight": 0.5,
}
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.,
"num_bins": 50,
"eps": 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": 1e-6,
"weight": 1.0,
},
"violation": {
"eps": 1e-6,
"weight": 0.,
},
},
}) })
...@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase): ...@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
a = torch.rand((batch_size, n, c_s)) a = torch.rand((batch_size, n, c_s))
a_initial = torch.rand((batch_size, n, c_s)) a_initial = torch.rand((batch_size, n, c_s))
a = ar(a, a_initial) _, a = ar(a, a_initial)
self.assertTrue(a.shape == (batch_size, n, no_angles, 2)) self.assertTrue(a.shape == (batch_size, n, no_angles, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment