Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
df6b97f2
"git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "ccf3bea4fd8d817c5c19561be945911642ab4972"
Commit
df6b97f2
authored
Sep 20, 2021
by
Gustaf Ahdritz
Browse files
First draft of loss class
parent
15895ea9
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
360 additions
and
86 deletions
+360
-86
alphafold/model/structure_module.py
alphafold/model/structure_module.py
+7
-3
alphafold/utils/loss.py
alphafold/utils/loss.py
+302
-82
config.py
config.py
+50
-0
tests/test_structure_module.py
tests/test_structure_module.py
+1
-1
No files found.
alphafold/model/structure_module.py
View file @
df6b97f2
...
@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
...
@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2]
# [*, no_angles * 2]
s
=
self
.
linear_out
(
s
)
s
=
self
.
linear_out
(
s
)
unnormalized_s
=
s
# [*, no_angles, 2]
# [*, no_angles, 2]
s
=
s
.
view
(
*
s
.
shape
[:
-
1
],
-
1
,
2
)
s
=
s
.
view
(
*
s
.
shape
[:
-
1
],
-
1
,
2
)
norm_denom
=
torch
.
sqrt
(
norm_denom
=
torch
.
sqrt
(
...
@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
...
@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
)
)
s
=
s
/
norm_denom
s
=
s
/
norm_denom
return
s
return
unnormalized_s
,
s
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
...
@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
...
@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
t
=
t
.
compose
(
self
.
bb_update
(
s
))
t
=
t
.
compose
(
self
.
bb_update
(
s
))
# [*, N, 7, 2]
# [*, N, 7, 2]
a
=
self
.
angle_resnet
(
s
,
s_initial
)
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
...
@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
...
@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
)
)
preds
=
{
preds
=
{
"
t
ra
nsformation
s"
:
"
f
ra
me
s"
:
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
,
"unnormalized_angles"
:
unnormalized_a
,
"angles"
:
a
,
"angles"
:
a
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
,
}
}
...
...
alphafold/utils/loss.py
View file @
df6b97f2
This diff is collapsed.
Click to expand it.
config.py
View file @
df6b97f2
...
@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
...
@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
"max_outer_iterations"
:
20
,
"max_outer_iterations"
:
20
,
"exclude_residues"
:
[],
"exclude_residues"
:
[],
},
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
1e-6
,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
1e-8
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.
,
},
"fape"
:
{
"backbone"
:
{
"clamp_distance"
:
10.
,
"loss_unit_distance"
:
10.
,
"weight"
:
0.5
,
}
"sidechain"
:
{
"clamp_distance"
:
10.
,
"length_scale"
:
10.
,
"weight"
:
0.5
,
}
"weight"
:
1.0
,
},
"lddt"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.
,
"num_bins"
:
50
,
"eps"
:
1e-10
,
"weight"
:
0.01
,
},
"masked_msa"
:
{
"eps"
:
1e-8
,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
1e-6
,
"weight"
:
1.0
,
},
"violation"
:
{
"eps"
:
1e-6
,
"weight"
:
0.
,
},
},
})
})
tests/test_structure_module.py
View file @
df6b97f2
...
@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
...
@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
a
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a_initial
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a_initial
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a
=
ar
(
a
,
a_initial
)
_
,
a
=
ar
(
a
,
a_initial
)
self
.
assertTrue
(
a
.
shape
==
(
batch_size
,
n
,
no_angles
,
2
))
self
.
assertTrue
(
a
.
shape
==
(
batch_size
,
n
,
no_angles
,
2
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment