Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
a80d5263
"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "baa129f71cf122a5fc0f1cffd1e9ecec7daba795"
Unverified
Commit
a80d5263
authored
Dec 29, 2022
by
shenggan
Committed by
GitHub
Dec 29, 2022
Browse files
support alphafold v2.3 param (#128)
parent
c3436dd1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
101 additions
and
40 deletions
+101
-40
fastfold/config.py
fastfold/config.py
+1
-1
fastfold/model/nn/triangular_multiplicative_update.py
fastfold/model/nn/triangular_multiplicative_update.py
+36
-10
fastfold/utils/import_weights.py
fastfold/utils/import_weights.py
+49
-25
fastfold/utils/inject_fastnn.py
fastfold/utils/inject_fastnn.py
+10
-4
inference.py
inference.py
+5
-0
No files found.
fastfold/config.py
View file @
a80d5263
...
@@ -575,7 +575,7 @@ multimer_model_config_update = {
...
@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm"
:
{
"tm"
:
{
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
tm_enabled
,
"enabled"
:
True
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_m"
:
c_m
,
...
...
fastfold/model/nn/triangular_multiplicative_update.py
View file @
a80d5263
...
@@ -22,6 +22,16 @@ import torch.nn as nn
...
@@ -22,6 +22,16 @@ import torch.nn as nn
from
fastfold.model.nn.primitives
import
Linear
,
LayerNorm
from
fastfold.model.nn.primitives
import
Linear
,
LayerNorm
from
fastfold.utils.tensor_utils
import
permute_final_dims
from
fastfold.utils.tensor_utils
import
permute_final_dims
_FUSED_TRIANGLE_MULTIPLICATION
=
False
def
set_fused_triangle_multiplication
():
global
_FUSED_TRIANGLE_MULTIPLICATION
_FUSED_TRIANGLE_MULTIPLICATION
=
True
def
is_fused_triangle_multiplication
():
global
_FUSED_TRIANGLE_MULTIPLICATION
return
_FUSED_TRIANGLE_MULTIPLICATION
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
"""
"""
...
@@ -40,11 +50,16 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -40,11 +50,16 @@ class TriangleMultiplicativeUpdate(nn.Module):
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
self
.
_outgoing
=
_outgoing
self
.
_outgoing
=
_outgoing
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
if
_FUSED_TRIANGLE_MULTIPLICATION
:
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_p
=
Linear
(
self
.
c_z
,
2
*
self
.
c_hidden
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
2
*
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_gate
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
else
:
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_b_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
self
.
linear_z
=
Linear
(
self
.
c_hidden
,
self
.
c_z
,
init
=
"final"
)
self
.
linear_z
=
Linear
(
self
.
c_hidden
,
self
.
c_z
,
init
=
"final"
)
self
.
layer_norm_in
=
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_in
=
LayerNorm
(
self
.
c_z
)
...
@@ -77,14 +92,25 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -77,14 +92,25 @@ class TriangleMultiplicativeUpdate(nn.Module):
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
z
=
self
.
layer_norm_in
(
z
)
z
=
self
.
layer_norm_in
(
z
)
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
if
_FUSED_TRIANGLE_MULTIPLICATION
:
a
=
a
*
mask
a
=
self
.
linear_p
(
z
)
*
mask
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
a
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
b
=
b
*
mask
a
,
b
=
a
.
chunk
(
2
,
dim
=-
1
)
else
:
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
a
=
a
*
mask
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
mask
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
if
_FUSED_TRIANGLE_MULTIPLICATION
:
g
=
self
.
sigmoid
(
self
.
linear_gate
(
z
))
else
:
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
z
=
x
*
g
z
=
x
*
g
return
z
return
z
...
...
fastfold/utils/import_weights.py
View file @
a80d5263
...
@@ -20,6 +20,7 @@ import numpy as np
...
@@ -20,6 +20,7 @@ import numpy as np
import
torch
import
torch
from
typing
import
Union
,
List
from
typing
import
Union
,
List
from
fastfold.model.nn.triangular_multiplicative_update
import
is_fused_triangle_multiplication
_NPZ_KEY_PREFIX
=
"alphafold/alphafold_iteration/"
_NPZ_KEY_PREFIX
=
"alphafold/alphafold_iteration/"
...
@@ -187,32 +188,55 @@ def get_translation_dict(model, version):
...
@@ -187,32 +188,55 @@ def get_translation_dict(model, version):
"feat_2d_weights"
:
LinearWeight
(
tri_att
.
linear
.
weight
),
"feat_2d_weights"
:
LinearWeight
(
tri_att
.
linear
.
weight
),
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
}
}
if
is_fused_triangle_multiplication
():
TriMulOutParams
=
lambda
tri_mul
:
{
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"projection"
:
LinearParams
(
tri_mul
.
linear_p
),
"gate"
:
LinearParams
(
tri_mul
.
linear_g
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_gate
),
}
TriMulOutParams
=
lambda
tri_mul
:
{
# see commit b88f8da on the Alphafold repo
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
# iterations of triangle multiplication, which is confusing and not
"right_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
# reproduced in our implementation.
"left_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
TriMulInParams
=
lambda
tri_mul
:
{
"right_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"projection"
:
LinearParams
(
tri_mul
.
linear_p
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gate"
:
LinearParams
(
tri_mul
.
linear_g
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
}
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_gate
),
}
else
:
TriMulOutParams
=
lambda
tri_mul
:
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
"right_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"left_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
}
# see commit b88f8da on the Alphafold repo
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
# iterations of triangle multiplication, which is confusing and not
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
# reproduced in our implementation.
TriMulInParams
=
lambda
tri_mul
:
{
TriMulInParams
=
lambda
tri_mul
:
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"left_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"right_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
"right_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
"left_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"left_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
}
}
PairTransitionParams
=
lambda
pt
:
{
PairTransitionParams
=
lambda
pt
:
{
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
...
@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
...
@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
if
"template_"
in
k
:
if
"template_"
in
k
:
evo_dict
.
pop
(
k
)
evo_dict
.
pop
(
k
)
if
"_ptm"
in
version
:
if
"_ptm"
in
version
or
is_multimer
:
translations
[
"predicted_aligned_error_head"
]
=
{
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
}
...
...
fastfold/utils/inject_fastnn.py
View file @
a80d5263
...
@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
...
@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from
fastfold.model.fastnn.embedders
import
TemplateEmbedder
from
fastfold.model.fastnn.embedders
import
TemplateEmbedder
from
fastfold.model.fastnn.embedders_multimer
import
TemplateEmbedderMultimer
from
fastfold.model.fastnn.embedders_multimer
import
TemplateEmbedderMultimer
from
fastfold.model.fastnn.ops
import
RecyclingEmbedder
,
InputEmbedder
from
fastfold.model.fastnn.ops
import
RecyclingEmbedder
,
InputEmbedder
from
fastfold.model.nn.triangular_multiplicative_update
import
is_fused_triangle_multiplication
def
copy_layernorm
(
model_fast
,
model_ori
):
def
copy_layernorm
(
model_fast
,
model_ori
):
...
@@ -72,13 +73,18 @@ def copy_transition(model_fast, model_ori):
...
@@ -72,13 +73,18 @@ def copy_transition(model_fast, model_ori):
def
copy_triangle
(
model_fast
,
model_ori
):
def
copy_triangle
(
model_fast
,
model_ori
):
copy_layernorm
(
model_fast
.
layernorm1
,
model_ori
.
layer_norm_in
)
copy_layernorm
(
model_fast
.
layernorm1
,
model_ori
.
layer_norm_in
)
copy_layernorm
(
model_fast
.
layernorm2
,
model_ori
.
layer_norm_out
)
copy_layernorm
(
model_fast
.
layernorm2
,
model_ori
.
layer_norm_out
)
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_g
)
copy_linear
(
model_fast
.
output_projection
,
model_ori
.
linear_z
)
copy_linear
(
model_fast
.
output_projection
,
model_ori
.
linear_z
)
model_fast
.
output_bias
.
copy_
(
model_ori
.
linear_z
.
bias
)
model_fast
.
output_bias
.
copy_
(
model_ori
.
linear_z
.
bias
)
copy_left_right
(
model_fast
.
left_right_projection
,
model_ori
.
linear_a_p
,
model_ori
.
linear_b_p
)
if
is_fused_triangle_multiplication
():
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_gate
)
copy_left_right
(
model_fast
.
left_right_gate
,
model_ori
.
linear_a_g
,
model_ori
.
linear_b_g
)
copy_linear
(
model_fast
.
left_right_projection
,
model_ori
.
linear_p
)
copy_linear
(
model_fast
.
left_right_gate
,
model_ori
.
linear_g
)
else
:
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_g
)
copy_left_right
(
model_fast
.
left_right_projection
,
model_ori
.
linear_a_p
,
model_ori
.
linear_b_p
)
copy_left_right
(
model_fast
.
left_right_gate
,
model_ori
.
linear_a_g
,
model_ori
.
linear_b_g
)
def
copy_triangle_att
(
model_fast
,
model_ori
):
def
copy_triangle_att
(
model_fast
,
model_ori
):
...
...
inference.py
View file @
a80d5263
...
@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
...
@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
from
fastfold.common
import
protein
,
residue_constants
from
fastfold.common
import
protein
,
residue_constants
from
fastfold.config
import
model_config
from
fastfold.config
import
model_config
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.nn.triangular_multiplicative_update
import
set_fused_triangle_multiplication
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
,
FastFoldMultimerDataWorkFlow
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
,
FastFoldMultimerDataWorkFlow
...
@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
...
@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
model_name
)
if
args
.
chunk_size
:
if
args
.
chunk_size
:
config
.
globals
.
chunk_size
=
args
.
chunk_size
config
.
globals
.
chunk_size
=
args
.
chunk_size
if
"v3"
in
args
.
param_path
:
set_fused_triangle_multiplication
()
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
is_multimer
=
args
.
model_preset
==
'multimer'
config
.
globals
.
is_multimer
=
args
.
model_preset
==
'multimer'
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment