Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
a80d5263
Unverified
Commit
a80d5263
authored
Dec 29, 2022
by
shenggan
Committed by
GitHub
Dec 29, 2022
Browse files
support alphafold v2.3 param (#128)
parent
c3436dd1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
101 additions
and
40 deletions
+101
-40
fastfold/config.py
fastfold/config.py
+1
-1
fastfold/model/nn/triangular_multiplicative_update.py
fastfold/model/nn/triangular_multiplicative_update.py
+36
-10
fastfold/utils/import_weights.py
fastfold/utils/import_weights.py
+49
-25
fastfold/utils/inject_fastnn.py
fastfold/utils/inject_fastnn.py
+10
-4
inference.py
inference.py
+5
-0
No files found.
fastfold/config.py
View file @
a80d5263
...
@@ -575,7 +575,7 @@ multimer_model_config_update = {
...
@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm"
:
{
"tm"
:
{
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
tm_enabled
,
"enabled"
:
True
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_m"
:
c_m
,
...
...
fastfold/model/nn/triangular_multiplicative_update.py
View file @
a80d5263
...
@@ -22,6 +22,16 @@ import torch.nn as nn
...
@@ -22,6 +22,16 @@ import torch.nn as nn
from
fastfold.model.nn.primitives
import
Linear
,
LayerNorm
from
fastfold.model.nn.primitives
import
Linear
,
LayerNorm
from
fastfold.utils.tensor_utils
import
permute_final_dims
from
fastfold.utils.tensor_utils
import
permute_final_dims
_FUSED_TRIANGLE_MULTIPLICATION
=
False
def
set_fused_triangle_multiplication
():
global
_FUSED_TRIANGLE_MULTIPLICATION
_FUSED_TRIANGLE_MULTIPLICATION
=
True
def
is_fused_triangle_multiplication
():
global
_FUSED_TRIANGLE_MULTIPLICATION
return
_FUSED_TRIANGLE_MULTIPLICATION
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
"""
"""
...
@@ -40,6 +50,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -40,6 +50,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
self
.
_outgoing
=
_outgoing
self
.
_outgoing
=
_outgoing
if
_FUSED_TRIANGLE_MULTIPLICATION
:
self
.
linear_p
=
Linear
(
self
.
c_z
,
2
*
self
.
c_hidden
)
self
.
linear_g
=
Linear
(
self
.
c_z
,
2
*
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_gate
=
Linear
(
self
.
c_z
,
self
.
c_z
,
init
=
"gating"
)
else
:
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_a_g
=
Linear
(
self
.
c_z
,
self
.
c_hidden
,
init
=
"gating"
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
self
.
linear_b_p
=
Linear
(
self
.
c_z
,
self
.
c_hidden
)
...
@@ -77,13 +92,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -77,13 +92,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
z
=
self
.
layer_norm_in
(
z
)
z
=
self
.
layer_norm_in
(
z
)
if
_FUSED_TRIANGLE_MULTIPLICATION
:
a
=
self
.
linear_p
(
z
)
*
mask
a
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
a
,
b
=
a
.
chunk
(
2
,
dim
=-
1
)
else
:
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
a
=
a
*
mask
a
=
a
*
mask
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
mask
b
=
b
*
mask
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
if
_FUSED_TRIANGLE_MULTIPLICATION
:
g
=
self
.
sigmoid
(
self
.
linear_gate
(
z
))
else
:
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
z
=
x
*
g
z
=
x
*
g
...
...
fastfold/utils/import_weights.py
View file @
a80d5263
...
@@ -20,6 +20,7 @@ import numpy as np
...
@@ -20,6 +20,7 @@ import numpy as np
import
torch
import
torch
from
typing
import
Union
,
List
from
typing
import
Union
,
List
from
fastfold.model.nn.triangular_multiplicative_update
import
is_fused_triangle_multiplication
_NPZ_KEY_PREFIX
=
"alphafold/alphafold_iteration/"
_NPZ_KEY_PREFIX
=
"alphafold/alphafold_iteration/"
...
@@ -188,6 +189,29 @@ def get_translation_dict(model, version):
...
@@ -188,6 +189,29 @@ def get_translation_dict(model, version):
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
}
}
if
is_fused_triangle_multiplication
():
TriMulOutParams
=
lambda
tri_mul
:
{
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"projection"
:
LinearParams
(
tri_mul
.
linear_p
),
"gate"
:
LinearParams
(
tri_mul
.
linear_g
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_gate
),
}
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams
=
lambda
tri_mul
:
{
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"projection"
:
LinearParams
(
tri_mul
.
linear_p
),
"gate"
:
LinearParams
(
tri_mul
.
linear_g
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_gate
),
}
else
:
TriMulOutParams
=
lambda
tri_mul
:
{
TriMulOutParams
=
lambda
tri_mul
:
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
...
@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
...
@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
if
"template_"
in
k
:
if
"template_"
in
k
:
evo_dict
.
pop
(
k
)
evo_dict
.
pop
(
k
)
if
"_ptm"
in
version
:
if
"_ptm"
in
version
or
is_multimer
:
translations
[
"predicted_aligned_error_head"
]
=
{
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
}
...
...
fastfold/utils/inject_fastnn.py
View file @
a80d5263
...
@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
...
@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from
fastfold.model.fastnn.embedders
import
TemplateEmbedder
from
fastfold.model.fastnn.embedders
import
TemplateEmbedder
from
fastfold.model.fastnn.embedders_multimer
import
TemplateEmbedderMultimer
from
fastfold.model.fastnn.embedders_multimer
import
TemplateEmbedderMultimer
from
fastfold.model.fastnn.ops
import
RecyclingEmbedder
,
InputEmbedder
from
fastfold.model.fastnn.ops
import
RecyclingEmbedder
,
InputEmbedder
from
fastfold.model.nn.triangular_multiplicative_update
import
is_fused_triangle_multiplication
def
copy_layernorm
(
model_fast
,
model_ori
):
def
copy_layernorm
(
model_fast
,
model_ori
):
...
@@ -72,12 +73,17 @@ def copy_transition(model_fast, model_ori):
...
@@ -72,12 +73,17 @@ def copy_transition(model_fast, model_ori):
def
copy_triangle
(
model_fast
,
model_ori
):
def
copy_triangle
(
model_fast
,
model_ori
):
copy_layernorm
(
model_fast
.
layernorm1
,
model_ori
.
layer_norm_in
)
copy_layernorm
(
model_fast
.
layernorm1
,
model_ori
.
layer_norm_in
)
copy_layernorm
(
model_fast
.
layernorm2
,
model_ori
.
layer_norm_out
)
copy_layernorm
(
model_fast
.
layernorm2
,
model_ori
.
layer_norm_out
)
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_g
)
copy_linear
(
model_fast
.
output_projection
,
model_ori
.
linear_z
)
copy_linear
(
model_fast
.
output_projection
,
model_ori
.
linear_z
)
model_fast
.
output_bias
.
copy_
(
model_ori
.
linear_z
.
bias
)
model_fast
.
output_bias
.
copy_
(
model_ori
.
linear_z
.
bias
)
if
is_fused_triangle_multiplication
():
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_gate
)
copy_linear
(
model_fast
.
left_right_projection
,
model_ori
.
linear_p
)
copy_linear
(
model_fast
.
left_right_gate
,
model_ori
.
linear_g
)
else
:
copy_linear
(
model_fast
.
output_gate
,
model_ori
.
linear_g
)
copy_left_right
(
model_fast
.
left_right_projection
,
model_ori
.
linear_a_p
,
model_ori
.
linear_b_p
)
copy_left_right
(
model_fast
.
left_right_projection
,
model_ori
.
linear_a_p
,
model_ori
.
linear_b_p
)
copy_left_right
(
model_fast
.
left_right_gate
,
model_ori
.
linear_a_g
,
model_ori
.
linear_b_g
)
copy_left_right
(
model_fast
.
left_right_gate
,
model_ori
.
linear_a_g
,
model_ori
.
linear_b_g
)
...
...
inference.py
View file @
a80d5263
...
@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
...
@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
from
fastfold.common
import
protein
,
residue_constants
from
fastfold.common
import
protein
,
residue_constants
from
fastfold.config
import
model_config
from
fastfold.config
import
model_config
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.nn.triangular_multiplicative_update
import
set_fused_triangle_multiplication
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data
import
data_pipeline
,
feature_pipeline
,
templates
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.data.tools
import
hhsearch
,
hmmsearch
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
,
FastFoldMultimerDataWorkFlow
from
fastfold.workflow.template
import
FastFoldDataWorkFlow
,
FastFoldMultimerDataWorkFlow
...
@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
...
@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
config
=
model_config
(
args
.
model_name
)
config
=
model_config
(
args
.
model_name
)
if
args
.
chunk_size
:
if
args
.
chunk_size
:
config
.
globals
.
chunk_size
=
args
.
chunk_size
config
.
globals
.
chunk_size
=
args
.
chunk_size
if
"v3"
in
args
.
param_path
:
set_fused_triangle_multiplication
()
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
is_multimer
=
args
.
model_preset
==
'multimer'
config
.
globals
.
is_multimer
=
args
.
model_preset
==
'multimer'
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment