OpenDAS / FastFold · Commits

Commit 9c0e7519 (Unverified)
Authored Sep 05, 2022 by Fazzie-Maqianli; committed by GitHub, Sep 05, 2022
Parent: ea7a6584

Multimer (#57)

* add import_weights
* add structure module
8 changed files, with 668 additions and 271 deletions (+668 −271):

fastfold/model/hub/alphafold.py        +67  −47
fastfold/model/nn/embedders.py         +103 −1
fastfold/model/nn/evoformer.py         +10  −2
fastfold/model/nn/structure_module.py  +146 −35
fastfold/utils/feats.py                +86  −65
fastfold/utils/import_weights.py       +253 −118
fastfold/utils/inject_fastnn.py        +2   −2
inference.py                           +1   −1
fastfold/model/hub/alphafold.py

@@ -17,6 +17,7 @@ from functools import partial
 import torch
 import torch.nn as nn
+from fastfold.data import data_transforms_multimer
 from fastfold.utils.feats import (
     pseudo_beta_fn,
     build_extra_msa_feat,
@@ -27,8 +28,7 @@ from fastfold.utils.feats import (
 from fastfold.model.nn.embedders import (
     InputEmbedder,
     RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
+    TemplateEmbedder,
     ExtraMSAEmbedder,
 )
 from fastfold.model.nn.embedders_multimer import TemplateEmbedderMultimer, InputEmbedderMultimer
@@ -36,10 +36,6 @@ from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
 from fastfold.model.nn.heads import AuxiliaryHeads
 import fastfold.common.residue_constants as residue_constants
 from fastfold.model.nn.structure_module import StructureModule
-from fastfold.model.nn.template import (
-    TemplatePairStack,
-    TemplatePointwiseAttention,
-)
 from fastfold.model.loss import (
     compute_plddt,
 )
@@ -81,24 +77,13 @@ class AlphaFold(nn.Module):
         self.input_embedder = InputEmbedder(
             **config["input_embedder"],
         )
-        self.template_angle_embedder = TemplateAngleEmbedder(
-            **template_config["template_angle_embedder"],
-        )
-        self.template_pair_embedder = TemplatePairEmbedder(
-            **template_config["template_pair_embedder"],
-        )
-        self.template_pair_stack = TemplatePairStack(
-            **template_config["template_pair_stack"],
-        )
-        self.template_pointwise_att = TemplatePointwiseAttention(
-            **template_config["template_pointwise_attention"],
-        )
+        self.template_embedder = TemplateEmbedder(
+            template_config,
+        )
         self.recycling_embedder = RecyclingEmbedder(
             **config["recycling_embedder"],
         )
         self.extra_msa_embedder = ExtraMSAEmbedder(
             **extra_msa_config["extra_msa_embedder"],
         )
@@ -210,10 +195,14 @@ class AlphaFold(nn.Module):
         # m: [*, S_c, N, C_m]
         # z: [*, N, N, C_z]
-        m, z = self.input_embedder(
-            feats["target_feat"],
-            feats["residue_index"],
-            feats["msa_feat"],
-        )
+        m, z = (
+            self.input_embedder(
+                feats["target_feat"],
+                feats["residue_index"],
+                feats["msa_feat"],
+            )
+            if not self.globals.is_multimer
+            else self.input_embedder(feats)
+        )

         # Initialize the recycling embeddings, if needs be
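The conditional expression above keeps a single call site for both embedders: the monomer `InputEmbedder` takes unpacked tensors, while `InputEmbedderMultimer` consumes the whole feature dict. A minimal sketch of that dispatch pattern, with trivial stand-in embedders (hypothetical, for illustration only):

```python
import torch

def monomer_embedder(target_feat, residue_index, msa_feat):
    # stand-in for InputEmbedder: unpacked-tensor signature
    return msa_feat, target_feat  # placeholder (m, z)

def multimer_embedder(feats):
    # stand-in for InputEmbedderMultimer: whole-dict signature
    return feats["msa_feat"], feats["target_feat"]

feats = {
    "target_feat": torch.randn(8, 22),
    "residue_index": torch.arange(8),
    "msa_feat": torch.randn(4, 8, 49),
}
is_multimer = True

# Same shape as the diff: one assignment, two embedder signatures.
m, z = (
    monomer_embedder(feats["target_feat"], feats["residue_index"], feats["msa_feat"])
    if not is_multimer
    else multimer_embedder(feats)
)
```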
@@ -236,9 +225,8 @@ class AlphaFold(nn.Module):
             requires_grad=False,
         )
-        x_prev = pseudo_beta_fn(
-            feats["aatype"], x_prev, None
-        ).to(dtype=z.dtype)
+        x_prev, _ = pseudo_beta_fn(feats["aatype"], x_prev, None)
+        x_prev = x_prev.to(dtype=z.dtype)

         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
@@ -270,40 +258,72 @@ class AlphaFold(nn.Module):
             template_feats = {
                 k: v for k, v in feats.items() if k.startswith("template_")
             }
-            template_embeds = self.embed_templates(
-                template_feats,
-                z,
-                pair_mask.to(dtype=z.dtype),
-                no_batch_dims,
-            )
+            if self.globals.is_multimer:
+                asym_id = feats["asym_id"]
+                multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    chunk_size=self.globals.chunk_size,
+                    multichain_mask_2d=multichain_mask_2d,
+                )
+                feats["template_torsion_angles_mask"] = (
+                    template_embeds["template_mask"]
+                )
+            else:
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    self.globals.chunk_size,
+                )

             # [*, N, N, C_z]
             z = z + template_embeds["template_pair_embedding"]

-            if self.config.template.embed_angles:
+            if (
+                self.config.template.embed_angles
+                or (self.globals.is_multimer and self.config.template.enabled)
+            ):
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
-                    [m, template_embeds["template_angle_embedding"]],
+                    [m, template_embeds["template_single_embedding"]],
                     dim=-3,
                 )

                 # [*, S, N]
-                torsion_angles_mask = feats["template_torsion_angles_mask"]
-                msa_mask = torch.cat(
-                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
-                    dim=-2
-                )
+                if not self.globals.is_multimer:
+                    torsion_angles_mask = feats["template_torsion_angles_mask"]
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], torsion_angles_mask[..., 2]],
+                        dim=-2,
+                    )
+                else:
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], template_embeds["template_mask"]],
+                        dim=-2,
+                    )

         # Embed extra MSA features + merge with pairwise embeddings
         if self.config.extra_msa.enabled:
+            if self.globals.is_multimer:
+                extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
+            else:
+                extra_msa_fn = build_extra_msa_feat
+
             # [*, S_e, N, C_e]
-            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
+            extra_msa_feat = extra_msa_fn(feats)
+            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

             # [*, N, N, C_z]
             z = self.extra_msa_stack(
-                a,
+                extra_msa_feat,
                 z,
-                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
+                msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
                 chunk_size=self.globals.chunk_size,
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
@@ -353,14 +373,14 @@ class AlphaFold(nn.Module):
         return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = None
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None

         for b in self.extra_msa_stack.blocks:
             b.ckpt = False

     def _enable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = (
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = (
             self.config.template.template_pair_stack.blocks_per_ckpt
         )
         self.evoformer.blocks_per_ckpt = (
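In the multimer branch above, the template embedder receives a 2-D mask marking residue pairs that belong to the same chain, derived by broadcasting `asym_id` against itself. A self-contained check of that broadcast (values made up):

```python
import torch

# asym_id assigns each residue to a chain; here two chains of lengths 3 and 2.
asym_id = torch.tensor([0, 0, 0, 1, 1])

# [N, 1] == [1, N] broadcasts to an [N, N] boolean matrix that is True
# exactly where residues i and j come from the same chain.
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]

print(multichain_mask_2d.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]])
```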
fastfold/model/nn/embedders.py

@@ -17,9 +17,20 @@ import torch
 import torch.nn as nn
 from typing import Tuple, Dict
+from functools import partial

+from fastfold.utils import all_atom_multimer
+from fastfold.utils.feats import (
+    build_template_angle_feat,
+    build_template_pair_feat,
+)
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot
+from fastfold.model.nn.template import (
+    TemplatePairStack,
+    TemplatePointwiseAttention,
+)
+from fastfold.utils import geometry
+from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap


 class InputEmbedder(nn.Module):
     """
@@ -221,6 +232,97 @@ class RecyclingEmbedder(nn.Module):
         return m_update, z_update


+class TemplateEmbedder(nn.Module):
+    def __init__(self, config):
+        super(TemplateEmbedder, self).__init__()
+
+        self.config = config
+        self.template_angle_embedder = TemplateAngleEmbedder(
+            **config["template_angle_embedder"],
+        )
+        self.template_pair_embedder = TemplatePairEmbedder(
+            **config["template_pair_embedder"],
+        )
+        self.template_pair_stack = TemplatePairStack(
+            **config["template_pair_stack"],
+        )
+        self.template_pointwise_att = TemplatePointwiseAttention(
+            **config["template_pointwise_attention"],
+        )
+
+    def forward(self, batch, z, pair_mask, templ_dim, chunk_size, _mask_trans=True):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[templ_dim]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            single_template_embeds = {}
+            if self.config.embed_angles:
+                template_angle_feat = build_template_angle_feat(
+                    single_template_feats,
+                )
+
+                # [*, S_t, N, C_m]
+                a = self.template_angle_embedder(template_angle_feat)
+
+                single_template_embeds["angle"] = a
+
+            # [*, S_t, N, N, C_t]
+            t = build_template_pair_feat(
+                single_template_feats,
+                use_unit_vector=self.config.use_unit_vector,
+                inf=self.config.inf,
+                eps=self.config.eps,
+                **self.config.distogram,
+            ).to(z.dtype)
+            t = self.template_pair_embedder(t)
+
+            single_template_embeds.update({"pair": t})
+
+            template_embeds.append(single_template_embeds)
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["pair"],
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        # [*, N, N, C_z]
+        t = self.template_pointwise_att(
+            t,
+            z,
+            template_mask=batch["template_mask"].to(dtype=z.dtype),
+            chunk_size=chunk_size,
+        )
+        t = t * (torch.sum(batch["template_mask"]) > 0)
+
+        ret = {}
+        if self.config.embed_angles:
+            ret["template_single_embedding"] = template_embeds["angle"]
+
+        ret.update({"template_pair_embedding": t})
+
+        return ret
+
+
 class TemplateAngleEmbedder(nn.Module):
     """
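The new `TemplateEmbedder.forward` loops over templates one at a time — the "poor man's vmap" in its comment — slicing every tensor in the batch and re-concatenating the per-template results along the template dimension. A reduced sketch of the pattern, with simplified stand-ins for `tensor_tree_map` and `dict_multimap` (not FastFold's implementations; a 1-D index is used here for clarity where the diff uses `new_tensor(i)`):

```python
import torch
from functools import partial

def tree_map(fn, d):
    # simplified stand-in for fastfold.utils.tensor_utils.tensor_tree_map
    return {k: fn(v) for k, v in d.items()}

def dict_multimap(fn, dicts):
    # simplified stand-in: applies fn across the list of per-template dicts
    return {k: fn([d[k] for d in dicts]) for k in dicts[0]}

batch = {
    "template_aatype": torch.randint(0, 20, (4, 8)),   # [S_t, N]
    "template_feat": torch.randn(4, 8, 16),            # [S_t, N, C]
}
templ_dim = 0

embeds = []
for i in range(batch["template_aatype"].shape[templ_dim]):
    idx = batch["template_aatype"].new_tensor([i])
    # slice out template i from every tensor, keeping the template dim
    single = tree_map(lambda t: torch.index_select(t, templ_dim, idx), batch)
    embeds.append({"pair": single["template_feat"] * 2.0})  # placeholder "embedding"

# stitch per-template results back together along the template dimension
embeds = dict_multimap(partial(torch.cat, dim=templ_dim), embeds)
assert embeds["pair"].shape == (4, 8, 16)
```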
fastfold/model/nn/evoformer.py

@@ -84,7 +84,6 @@ class MSATransition(nn.Module):
             no_batch_dims=len(m.shape[:-2]),
         )
-
     def forward(
         self,
         m: torch.Tensor,
@@ -101,10 +100,12 @@ class MSATransition(nn.Module):
             m:
                 [*, N_seq, N_res, C_m] MSA activation update
         """
         # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
         if mask is None:
             mask = m.new_ones(m.shape[:-1])

+        # [*, N_seq, N_res, 1]
         mask = mask.unsqueeze(-1)
+
         m = self.layer_norm(m)
@@ -132,9 +133,10 @@ class EvoformerBlockCore(nn.Module):
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
+        is_multimer: bool = False,
     ):
         super(EvoformerBlockCore, self).__init__()
+        self.is_multimer = is_multimer
         self.msa_transition = MSATransition(
             c_m=c_m,
             n=transition_n,
@@ -260,6 +262,12 @@ class EvoformerBlock(nn.Module):
             inf=inf,
             eps=eps,
         )
+        self.outer_product_mean = OuterProductMean(
+            c_m,
+            c_z,
+            c_hidden_opm,
+        )

     def forward(self,
         m: torch.Tensor,
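The annotated `mask.unsqueeze(-1)` gives the per-position MSA mask a trailing singleton so it broadcasts over the channel dimension of `m`. A quick standalone illustration:

```python
import torch

m = torch.randn(2, 5, 8, 16)   # [*, N_seq, N_res, C_m]
mask = torch.ones(2, 5, 8)     # [*, N_seq, N_res]
mask[..., -2:] = 0             # pretend the last two residues are padding

# [*, N_seq, N_res, 1] broadcasts cleanly across the C_m channels
masked = m * mask.unsqueeze(-1)
assert torch.all(masked[..., -2:, :] == 0)
```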
fastfold/model/nn/structure_module.py

@@ -16,7 +16,7 @@
 import math
 import torch
 import torch.nn as nn
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from fastfold.common.residue_constants import (
@@ -73,7 +73,9 @@ class AngleResnet(nn.Module):
     Implements Algorithm 20, lines 11-14
     """

-    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
+    def __init__(
+        self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float
+    ):
         """
         Args:
             c_in:
@@ -145,7 +147,7 @@ class AngleResnet(nn.Module):
         unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
-                torch.sum(s ** 2, dim=-1, keepdim=True),
+                torch.sum(s**2, dim=-1, keepdim=True),
                 min=self.eps,
             )
         )
@@ -153,6 +155,7 @@ class AngleResnet(nn.Module):
         return unnormalized_s, s

+
 class PointProjection(nn.Module):
     def __init__(
         self,
@@ -491,7 +494,7 @@ class BackboneUpdate(nn.Module):
     Implements part of Algorithm 23.
     """

-    def __init__(self, c_s):
+    def __init__(self, c_s: int):
         """
         Args:
             c_s:
@@ -517,7 +520,7 @@
 class StructureModuleTransitionLayer(nn.Module):
-    def __init__(self, c):
+    def __init__(self, c: int):
         super(StructureModuleTransitionLayer, self).__init__()

         self.c = c
@@ -528,7 +531,7 @@ class StructureModuleTransitionLayer(nn.Module):
         self.relu = nn.ReLU()

-    def forward(self, s):
+    def forward(self, s: torch.Tensor):
         s_initial = s
         s = self.linear_1(s)
         s = self.relu(s)
@@ -542,7 +545,7 @@ class StructureModuleTransitionLayer(nn.Module):
 class StructureModuleTransition(nn.Module):
-    def __init__(self, c, num_layers, dropout_rate):
+    def __init__(self, c: int, num_layers: int, dropout_rate: float):
         super(StructureModuleTransition, self).__init__()

         self.c = c
@@ -557,7 +560,7 @@ class StructureModuleTransition(nn.Module):
         self.dropout = nn.Dropout(self.dropout_rate)
         self.layer_norm = LayerNorm(self.c)

-    def forward(self, s):
+    def forward(self, s: torch.Tensor) -> torch.Tensor:
         for l in self.layers:
             s = l(s)
@@ -570,22 +573,22 @@ class StructureModuleTransition(nn.Module):
 class StructureModule(nn.Module):
     def __init__(
         self,
-        c_s,
-        c_z,
-        c_ipa,
-        c_resnet,
-        no_heads_ipa,
-        no_qk_points,
-        no_v_points,
-        dropout_rate,
-        no_blocks,
-        no_transition_layers,
-        no_resnet_blocks,
-        no_angles,
-        trans_scale_factor,
-        epsilon,
-        inf,
-        is_multimer=False,
+        c_s: int,
+        c_z: int,
+        c_ipa: int,
+        c_resnet: int,
+        no_heads_ipa: int,
+        no_qk_points: int,
+        no_v_points: int,
+        dropout_rate: float,
+        no_blocks: int,
+        no_transition_layers: int,
+        no_resnet_blocks: int,
+        no_angles: int,
+        trans_scale_factor: float,
+        epsilon: float,
+        inf: float,
+        is_multimer: bool = False,
         **kwargs,
     ):
         """
@@ -621,6 +624,8 @@ class StructureModule(nn.Module):
                 Small number used in angle resnet normalization
             inf:
                 Large number used for attention masking
+            is_multimer:
+                whether running under multimer mode
         """
         super(StructureModule, self).__init__()
@@ -673,7 +678,10 @@ class StructureModule(nn.Module):
             self.dropout_rate,
         )

-        self.bb_update = BackboneUpdate(self.c_s)
+        if is_multimer:
+            self.bb_update = QuatRigid(self.c_s, full_quat=False)
+        else:
+            self.bb_update = BackboneUpdate(self.c_s)

         self.angle_resnet = AngleResnet(
             self.c_s,
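The backbone-update head is now chosen at construction time: monomer keeps the linear `BackboneUpdate`, multimer switches to the quaternion-based `QuatRigid` with `full_quat=False`. A minimal sketch of this constructor-time selection; the stub heads below are placeholders, not FastFold's implementations, and the 6-output assumption for the partial-quaternion case is ours:

```python
import torch.nn as nn

class BackboneUpdateStub(nn.Module):
    # placeholder for BackboneUpdate: a single linear layer
    def __init__(self, c_s):
        super().__init__()
        self.linear = nn.Linear(c_s, 6)

class QuatRigidStub(nn.Module):
    # placeholder for QuatRigid(c_s, full_quat=False); with a partial
    # quaternion, fewer output components are needed than full_quat=True
    def __init__(self, c_s, full_quat=False):
        super().__init__()
        self.linear = nn.Linear(c_s, 7 if full_quat else 6)

def make_bb_update(c_s, is_multimer):
    # same branch shape as the diff's __init__
    if is_multimer:
        return QuatRigidStub(c_s, full_quat=False)
    return BackboneUpdateStub(c_s)
```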
@@ -683,13 +691,13 @@ class StructureModule(nn.Module):
             self.epsilon,
         )

-    def forward(
+    def _forward_monomer(
         self,
-        s,
-        z,
-        aatype,
-        mask=None,
-    ):
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
         """
         Args:
             s:
@@ -785,7 +793,103 @@ class StructureModule(nn.Module):
         return outputs

-    def _init_residue_constants(self, float_dtype, device):
+    def _forward_multimer(
+        self,
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
+        if mask is None:
+            # [*, N]
+            mask = s.new_ones(s.shape[:-1])
+
+        # [*, N, C_s]
+        s = self.layer_norm_s(s)
+
+        # [*, N, N, C_z]
+        z = self.layer_norm_z(z)
+
+        # [*, N, C_s]
+        s_initial = s
+        s = self.linear_in(s)
+
+        # [*, N]
+        rigids = Rigid3Array.identity(
+            s.shape[:-1],
+            s.device,
+        )
+        outputs = []
+        for i in range(self.no_blocks):
+            # [*, N, C_s]
+            s = s + self.ipa(s, z, rigids, mask)
+            s = self.ipa_dropout(s)
+            s = self.layer_norm_ipa(s)
+            s = self.transition(s)
+
+            # [*, N]
+            rigids = rigids @ self.bb_update(s)
+
+            # [*, N, 7, 2]
+            unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+            all_frames_to_global = self.torsion_angles_to_frames(
+                rigids.scale_translation(self.trans_scale_factor),
+                angles,
+                aatype,
+            )
+
+            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
+                all_frames_to_global,
+                aatype,
+            )
+
+            preds = {
+                "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
+                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+                "unnormalized_angles": unnormalized_angles,
+                "angles": angles,
+                "positions": pred_xyz.to_tensor(),
+            }
+
+            outputs.append(preds)
+
+            if i < (self.no_blocks - 1):
+                rigids = rigids.stop_rot_gradient()
+
+        outputs = dict_multimap(torch.stack, outputs)
+        outputs["single"] = s
+
+        return outputs
+
+    def forward(
+        self,
+        s: torch.Tensor,
+        z: torch.Tensor,
+        aatype: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+            z:
+                [*, N_res, N_res, C_z] pair representation
+            aatype:
+                [*, N_res] amino acid indices
+            mask:
+                Optional [*, N_res] sequence mask
+        Returns:
+            A dictionary of outputs
+        """
+        if self.is_multimer:
+            outputs = self._forward_multimer(s, z, aatype, mask)
+        else:
+            outputs = self._forward_monomer(s, z, aatype, mask)
+
+        return outputs
+
+    def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device):
         if self.default_frames is None:
             self.default_frames = torch.tensor(
                 restype_rigid_group_default_frame,
@@ -814,17 +918,24 @@ class StructureModule(nn.Module):
                 requires_grad=False,
             )

-    def torsion_angles_to_frames(self, r, alpha, f):
+    def torsion_angles_to_frames(
+        self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f
+    ):
         # Lazily initialize the residue constants on the correct device
         self._init_residue_constants(alpha.dtype, alpha.device)
         # Separated purely to make testing less annoying
         return torsion_angles_to_frames(r, alpha, f, self.default_frames)

     def frames_and_literature_positions_to_atom14_pos(
-        self, r, f  # [*, N, 8]  # [*, N]
+        self, r: Union[Rigid, Rigid3Array], f  # [*, N, 8]  # [*, N]
     ):
         # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        if type(r) == Rigid:
+            self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        elif type(r) == Rigid3Array:
+            self._init_residue_constants(r.dtype, r.device)
+        else:
+            raise ValueError("Unknown rigid type")
         return frames_and_literature_positions_to_atom14_pos(
             r,
             f,
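`forward` is now a thin dispatcher over `_forward_monomer` and `_forward_multimer`, so existing call sites are untouched. The same pattern, stripped to its skeleton (stand-in bodies only, for illustration):

```python
import torch
import torch.nn as nn
from typing import Optional

class DispatchingModule(nn.Module):
    def __init__(self, is_multimer: bool = False):
        super().__init__()
        self.is_multimer = is_multimer

    def _forward_monomer(self, s, z, aatype, mask):
        return {"single": s}        # placeholder body

    def _forward_multimer(self, s, z, aatype, mask):
        return {"single": s * 2}    # placeholder body

    def forward(self, s, z, aatype, mask: Optional[torch.Tensor] = None):
        # one public entry point, two mode-specific implementations
        if self.is_multimer:
            return self._forward_multimer(s, z, aatype, mask)
        return self._forward_monomer(s, z, aatype, mask)
```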
fastfold/utils/feats.py

@@ -18,10 +18,12 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Dict
+from typing import Any, Dict, Optional, Tuple, Union

 from fastfold.common import protein
 import fastfold.common.residue_constants as rc
+from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from fastfold.utils.geometry.rotation_matrix import Rot3Array
 from fastfold.utils.rigid_utils import Rotation, Rigid
 from fastfold.utils.tensor_utils import (
     batched_gather,
@@ -36,7 +38,7 @@ def dgram_from_positions(
     max_bin: float = 50.75,
     no_bins: float = 39,
     inf: float = 1e8,
-):
+) -> torch.Tensor:
     dgram = torch.sum(
         (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
     )
@@ -46,8 +48,9 @@ def dgram_from_positions(
     return dgram


-def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+def pseudo_beta_fn(
+    aatype, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     is_gly = aatype == rc.restype_order["G"]
     ca_idx = rc.atom_order["CA"]
     cb_idx = rc.atom_order["CB"]
@@ -65,10 +68,10 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
         )
         return pseudo_beta, pseudo_beta_mask
     else:
-        return pseudo_beta
+        return pseudo_beta, None


-def atom14_to_atom37(atom14, batch):
+def atom14_to_atom37(atom14, batch: Dict[str, Any]):
     atom37_data = batched_gather(
         atom14,
         batch["residx_atom37_to_atom14"],
@@ -81,19 +84,15 @@ def atom14_to_atom37(atom14, batch):
     return atom37_data


-def build_template_angle_feat(template_feats):
+def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor:
     template_aatype = template_feats["template_aatype"]
     torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
     alt_torsion_angles_sin_cos = template_feats[
         "template_alt_torsion_angles_sin_cos"
     ]
     torsion_angles_mask = template_feats["template_torsion_angles_mask"]
     template_angle_feat = torch.cat(
         [
             nn.functional.one_hot(template_aatype, 22),
-            torsion_angles_sin_cos.reshape(
-                *torsion_angles_sin_cos.shape[:-2], 14
-            ),
+            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
             alt_torsion_angles_sin_cos.reshape(
                 *alt_torsion_angles_sin_cos.shape[:-2], 14
             ),
@@ -106,22 +105,20 @@ def build_template_angle_feat(template_feats):
 def build_template_pair_feat(
-    batch,
-    min_bin,
-    max_bin,
-    no_bins,
-    use_unit_vector=False,
-    eps=1e-20,
-    inf=1e8
+    batch: Dict[str, Any],
+    min_bin: float,
+    max_bin: float,
+    no_bins: int,
+    use_unit_vector: bool = False,
+    eps: float = 1e-20,
+    inf: float = 1e8,
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

     # Compute distogram (this seems to differ slightly from Alg. 5)
     tpb = batch["template_pseudo_beta"]
-    dgram = torch.sum(
-        (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
-    )
-    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
-    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+    dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)

     to_concat = [dgram, template_mask_2d[..., None]]
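The removed lines show the binning that `dgram_from_positions` now centralizes: squared pairwise distances are compared against squared bin edges, yielding a one-hot distogram without ever taking a square root. A self-contained sketch of that computation (the `min_bin` default of 3.25 is an assumption here; the diff only shows the `max_bin`/`no_bins`/`inf` defaults):

```python
import torch

def dgram_from_positions_sketch(pos, min_bin=3.25, max_bin=50.75, no_bins=39, inf=1e8):
    # squared pairwise distances, kept as [..., N, N, 1] for broadcasting
    d2 = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # bin edges are squared so no sqrt is needed on the distances
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    # exactly one (lower, upper) interval contains each distance -> one-hot
    return ((d2 > lower) * (d2 < upper)).type(pos.dtype)

pos = torch.randn(8, 3) * 10.0
dgram = dgram_from_positions_sketch(pos)
assert dgram.shape == (8, 8, 39)
```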
@@ -137,9 +134,7 @@ def build_template_pair_feat(
         )
     )
     to_concat.append(
-        aatype_one_hot[..., None, :].expand(
-            *aatype_one_hot.shape[:-2], -1, n_res, -1
-        )
+        aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)
     )

     n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
@@ -152,19 +147,17 @@ def build_template_pair_feat(
     points = rigids.get_trans()[..., None, :, :]
     rigid_vec = rigids[..., None].invert_apply(points)

     inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))

     t_aa_masks = batch["template_all_atom_mask"]
-    template_mask = (
-        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
-    )
+    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

     inv_distance_scalar = inv_distance_scalar * template_mask_2d
     unit_vector = rigid_vec * inv_distance_scalar[..., None]

-    if (not use_unit_vector):
-        unit_vector = unit_vector * 0.
+    if not use_unit_vector:
+        unit_vector = unit_vector * 0.0

     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
@@ -175,7 +168,7 @@ def build_template_pair_feat(
     return act


-def build_extra_msa_feat(batch):
+def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor:
     msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
     msa_feat = [
         msa_1hot,
@@ -186,11 +179,11 @@ def build_extra_msa_feat(batch):
 def torsion_angles_to_frames(
-    r: Rigid,
+    r: Union[Rigid3Array, Rigid],
     alpha: torch.Tensor,
     aatype: torch.Tensor,
     rrgdf: torch.Tensor,
-):
+) -> Union[Rigid, Rigid3Array]:
     # [*, N, 8, 4, 4]
     default_4x4 = rrgdf[aatype, ...]
@@ -203,9 +196,7 @@ def torsion_angles_to_frames(
     bb_rot[..., 1] = 1

     # [*, N, 8, 2]
-    alpha = torch.cat(
-        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
-    )
+    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

     # [*, N, 8, 3, 3]
     # Produces rotation matrices of the form:
@@ -216,16 +207,26 @@ def torsion_angles_to_frames(
     # ]
     # This follows the original code rather than the supplement, which uses
     # different indices.
-
-    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+    if type(r) == Rigid3Array:
+        all_rots = alpha.new_zeros(default_r.shape + (3, 3))
+    elif type(r) == Rigid:
+        all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")
     all_rots[..., 0, 0] = 1
     all_rots[..., 1, 1] = alpha[..., 1]
     all_rots[..., 1, 2] = -alpha[..., 0]
     all_rots[..., 2, 1:] = alpha

-    all_rots = Rigid(Rotation(rot_mats=all_rots), None)
-    all_frames = default_r.compose(all_rots)
+    if type(r) == Rigid3Array:
+        all_rots = Rot3Array.from_array(all_rots)
+        all_frames = default_r.compose_rotation(all_rots)
+    elif type(r) == Rigid:
+        all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+        all_frames = default_r.compose(all_rots)
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")

     chi2_frame_to_frame = all_frames[..., 5]
     chi3_frame_to_frame = all_frames[..., 6]
@@ -236,15 +237,26 @@ def torsion_angles_to_frames(
     chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
     chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

-    all_frames_to_bb = Rigid.cat(
-        [
-            all_frames[..., :5],
-            chi2_frame_to_bb.unsqueeze(-1),
-            chi3_frame_to_bb.unsqueeze(-1),
-            chi4_frame_to_bb.unsqueeze(-1),
-        ],
-        dim=-1,
-    )
+    if type(all_frames) == Rigid3Array:
+        all_frames_to_bb = Rigid3Array.cat(
+            [
+                all_frames[..., :5],
+                chi2_frame_to_bb.unsqueeze(-1),
+                chi3_frame_to_bb.unsqueeze(-1),
+                chi4_frame_to_bb.unsqueeze(-1),
+            ],
+            dim=-1,
+        )
+    elif type(all_frames) == Rigid:
+        all_frames_to_bb = Rigid.cat(
+            [
+                all_frames[..., :5],
+                chi2_frame_to_bb.unsqueeze(-1),
+                chi3_frame_to_bb.unsqueeze(-1),
+                chi4_frame_to_bb.unsqueeze(-1),
+            ],
+            dim=-1,
+        )

     all_frames_to_global = r[..., None].compose(all_frames_to_bb)
@@ -252,13 +264,13 @@ def torsion_angles_to_frames(
 def frames_and_literature_positions_to_atom14_pos(
-    r: Rigid,
+    r: Union[Rigid3Array, Rigid],
     aatype: torch.Tensor,
-    default_frames,
-    group_idx,
-    atom_mask,
-    lit_positions,
-):
+    default_frames: torch.Tensor,
+    group_idx: torch.Tensor,
+    atom_mask: torch.Tensor,
+    lit_positions: torch.Tensor,
+) -> torch.Tensor:
     # [*, N, 14, 4, 4]
     default_4x4 = default_frames[aatype, ...]
@@ -266,21 +278,30 @@ def frames_and_literature_positions_to_atom14_pos(
     group_mask = group_idx[aatype, ...]

     # [*, N, 14, 8]
-    group_mask = nn.functional.one_hot(
-        group_mask,
-        num_classes=default_frames.shape[-3],
-    )
+    if type(r) == Rigid3Array:
+        group_mask = nn.functional.one_hot(
+            group_mask.long(),
+            num_classes=default_frames.shape[-3],
+        )
+    elif type(r) == Rigid:
+        group_mask = nn.functional.one_hot(
+            group_mask,
+            num_classes=default_frames.shape[-3],
+        )
+    else:
+        raise TypeError(f"Wrong type of Rigid: {type(r)}")

     # [*, N, 14, 8]
     t_atoms_to_global = r[..., None, :] * group_mask

     # [*, N, 14]
-    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
+    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
+        lambda x: torch.sum(x, dim=-1)
+    )

     # [*, N, 14, 1]
-    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+    if type(r) == Rigid:
+        atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
+    elif type(r) == Rigid3Array:
+        atom_mask = atom_mask[aatype, ...]

     # [*, N, 14, 3]
     lit_positions = lit_positions[aatype, ...]
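Several functions in this file now branch on `type(r)` because the legacy `Rigid` and the multimer `Rigid3Array` expose dtype and device differently. A toy sketch of the dispatch shape (the two classes here are stand-ins, not FastFold's geometry types):

```python
import torch

class ToyRigid:
    # stands in for fastfold.utils.rigid_utils.Rigid
    def __init__(self, rots):
        self._rots = rots
    def get_rots(self):
        return self._rots

class ToyRigid3Array:
    # stands in for fastfold.utils.geometry.rigid_matrix_vector.Rigid3Array,
    # which exposes dtype/device directly
    def __init__(self, rots):
        self.dtype = rots.dtype
        self.device = rots.device

def dtype_and_device(r):
    # same exact-type dispatch the diff uses before _init_residue_constants
    if type(r) == ToyRigid:
        return r.get_rots().dtype, r.get_rots().device
    elif type(r) == ToyRigid3Array:
        return r.dtype, r.device
    raise TypeError(f"Wrong type of Rigid: {type(r)}")

print(dtype_and_device(ToyRigid3Array(torch.eye(3))))
```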
fastfold/utils/import_weights.py

@@ -39,6 +39,12 @@ class ParamType(Enum):
     LinearWeightOPM = partial(
         lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
     )
+    LinearWeightMultimer = partial(
+        lambda w: w.unsqueeze(-1)
+        if len(w.shape) == 1
+        else w.reshape(w.shape[0], -1).transpose(-1, -2)
+    )
+    LinearBiasMultimer = partial(lambda w: w.reshape(-1))
     Other = partial(lambda w: w)

     def __init__(self, fn):
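The new `LinearWeightMultimer` transform flattens a JAX-style weight that keeps a separate per-head axis into torch's `nn.Linear` layout of `[out_features, in_features]`. A standalone check of what the lambda does (the example shape is illustrative):

```python
import torch

linear_weight_multimer = (
    lambda w: w.unsqueeze(-1)
    if len(w.shape) == 1
    else w.reshape(w.shape[0], -1).transpose(-1, -2)
)

# e.g. a projection weight of shape [c_in=8, no_heads=4, c_hidden=16]
w = torch.randn(8, 4, 16)
print(linear_weight_multimer(w).shape)
# torch.Size([64, 8]) -- i.e. nn.Linear(8, 64).weight layout
```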
@@ -121,29 +127,30 @@ def assign(translation_dict, orig_weights):
             print(weights[0].shape)
             raise


-def import_jax_weights_(model, npz_path, version="model_1"):
-    data = np.load(npz_path)
-
+def get_translation_dict(model, is_multimer: bool = False):
     #######################
     # Some templates
     #######################
     LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
     LinearBias = lambda l: (Param(l))
     LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
     LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
     LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+    LinearWeightMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearWeightMultimer)
+    )
+    LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer))

     LinearParams = lambda l: {
         "weights": LinearWeight(l.weight),
         "bias": LinearBias(l.bias),
     }

+    LinearParamsMultimer = lambda l: {
+        "weights": LinearWeightMultimer(l.weight),
+        "bias": LinearBiasMultimer(l.bias),
+    }
+
     LayerNormParams = lambda l: {
         "scale": Param(l.weight),
         "offset": Param(l.bias),
@@ -239,7 +246,43 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         "q_scalar": LinearParams(ipa.linear_q),
         "kv_scalar": LinearParams(ipa.linear_kv),
         "q_point_local": LinearParams(ipa.linear_q_points),
+        # New style IPA param
+        # "q_point_local": LinearParams(ipa.linear_q_points.linear),
         "kv_point_local": LinearParams(ipa.linear_kv_points),
+        # New style IPA param
+        # "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
         "trainable_point_weights": Param(
             param=ipa.head_weights, param_type=ParamType.Other
         ),
         "attention_2d": LinearParams(ipa.linear_b),
         "output_projection": LinearParams(ipa.linear_out),
     }

+    PointProjectionParams = lambda pp: {
+        "point_projection": LinearParamsMultimer(
+            pp.linear,
+        ),
+    }
+
+    IPAParamsMultimer = lambda ipa: {
+        "q_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_q.weight,
+            ),
+        },
+        "k_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_k.weight,
+            ),
+        },
+        "v_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_v.weight,
+            ),
+        },
+        "q_point_projection": PointProjectionParams(ipa.linear_q_points),
+        "k_point_projection": PointProjectionParams(ipa.linear_k_points),
+        "v_point_projection": PointProjectionParams(ipa.linear_v_points),
+        "trainable_point_weights": Param(
+            param=ipa.head_weights, param_type=ParamType.Other
+        ),
@@ -278,56 +321,54 @@ def import_jax_weights_(model, npz_path, version="model_1"):
         msa_col_att_params = MSAColAttParams(b.msa_att_col)

         d = {
             "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                 b.msa_att_row
             ),
             col_att_name: msa_col_att_params,
             "msa_transition": MSATransitionParams(b.core.msa_transition),
             "outer_product_mean":
                 OuterProductMeanParams(b.core.outer_product_mean),
-            "triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
-            "triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
+            "triangle_multiplication_outgoing":
+                TriMulOutParams(b.core.tri_mul_out),
+            "triangle_multiplication_incoming":
+                TriMulInParams(b.core.tri_mul_in),
             "triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
             "triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
             "pair_transition": PairTransitionParams(b.core.pair_transition),
         }

         return d

     ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

-    FoldIterationParams = lambda sm: {
-        "invariant_point_attention": IPAParams(sm.ipa),
-        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
-        "transition": LinearParams(sm.transition.layers[0].linear_1),
-        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
-        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
-        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
-        "affine_update": LinearParams(sm.bb_update.linear),
-        "rigid_sidechain": {
-            "input_projection": LinearParams(sm.angle_resnet.linear_in),
-            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
-            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
-            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
-            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
-            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
-            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
-        },
-    }
+    def FoldIterationParams(sm):
+        d = {
+            "invariant_point_attention": IPAParamsMultimer(sm.ipa)
+            if is_multimer
+            else IPAParams(sm.ipa),
+            "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
+            "transition": LinearParams(sm.transition.layers[0].linear_1),
+            "transition_1": LinearParams(sm.transition.layers[0].linear_2),
+            "transition_2": LinearParams(sm.transition.layers[0].linear_3),
+            "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
+            "affine_update": LinearParams(sm.bb_update.linear),
+            "rigid_sidechain": {
+                "input_projection": LinearParams(sm.angle_resnet.linear_in),
+                "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
+                "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
+                "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
+                "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
+                "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
+                "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
+            },
+        }
+        if is_multimer:
+            d.pop("affine_update")
+            d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}
+        return d

     ############################
     # translations dict overflow
     ############################

-    tps_blocks = model.template_pair_stack.blocks
-    tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])
+    tps_blocks = model.template_embedder.template_pair_stack.blocks
+    tps_blocks_params = stacked(
+        [TemplatePairBlockParams(b) for b in tps_blocks]
+    )

     ems_blocks = model.extra_msa_stack.blocks
     ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
@@ -335,81 +376,175 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     evo_blocks = model.evoformer.blocks
     evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

-    translations = {
-        "evoformer": {
-            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
-            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
-            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
-            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
-            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
-            "prev_msa_first_row_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_m
-            ),
-            "prev_pair_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_z
-            ),
-            "pair_activiations": LinearParams(
-                model.input_embedder.linear_relpos
-            ),
-            "template_embedding": {
-                "single_template_embedding": {
-                    "embedding2d": LinearParams(
-                        model.template_pair_embedder.linear
-                    ),
-                    "template_pair_stack": {
-                        "__layer_stack_no_state": tps_blocks_params,
-                    },
-                    "output_layer_norm": LayerNormParams(
-                        model.template_pair_stack.layer_norm
-                    ),
-                },
-                "attention": AttentionParams(model.template_pointwise_att.mha),
-            },
-            "extra_msa_activations": LinearParams(
-                model.extra_msa_embedder.linear
-            ),
-            "extra_msa_stack": ems_blocks_params,
-            "template_single_embedding": LinearParams(
-                model.template_angle_embedder.linear_1
-            ),
-            "template_projection": LinearParams(
-                model.template_angle_embedder.linear_2
-            ),
-            "evoformer_iteration": evo_blocks_params,
-            "single_activations": LinearParams(model.evoformer.linear),
-        },
-        "structure_module": {
-            "single_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_s
-            ),
-            "initial_projection": LinearParams(model.structure_module.linear_in),
-            "pair_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_z
-            ),
-            "fold_iteration": FoldIterationParams(model.structure_module),
-        },
-        "predicted_lddt_head": {
-            "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
-            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
-            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
-            "logits": LinearParams(model.aux_heads.plddt.linear_3),
-        },
-        "distogram_head": {
-            "half_logits": LinearParams(model.aux_heads.distogram.linear),
-        },
-        "experimentally_resolved_head": {
-            "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
-        },
-        "masked_msa_head": {
-            "logits": LinearParams(model.aux_heads.masked_msa.linear),
-        },
-    }
+    if not is_multimer:
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "pair_activiations": LinearParams(
+                    model.input_embedder.linear_relpos
+                ),
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "embedding2d": LinearParams(
+                            model.template_embedder.template_pair_embedder.linear
+                        ),
+                        "template_pair_stack": {
+                            "__layer_stack_no_state": tps_blocks_params,
+                        },
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "attention": AttentionParams(
+                        model.template_embedder.template_pointwise_att.mha
+                    ),
+                },
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "template_single_embedding": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_1
+                ),
+                "template_projection": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_2
+                ),
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(model.structure_module.linear_in),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+    else:
+        temp_embedder = model.template_embedder
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "~_relative_encoding": {
+                    "position_activations": LinearParams(
+                        model.input_embedder.linear_relpos
+                    ),
+                },
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "query_embedding_norm": LayerNormParams(
+                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
+                        ),
+                        "template_pair_embedding_0": LinearParams(
+                            temp_embedder.template_pair_embedder.dgram_linear
+                        ),
+                        "template_pair_embedding_1": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
+                        ),
+                        "template_pair_embedding_2": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_1
+                        ),
+                        "template_pair_embedding_3": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_2
+                        ),
+                        "template_pair_embedding_4": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.x_linear
+                        ),
+                        "template_pair_embedding_5": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.y_linear
+                        ),
+                        "template_pair_embedding_6": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.z_linear
+                        ),
+                        "template_pair_embedding_7": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.backbone_mask_linear
+                        ),
+                        "template_pair_embedding_8": LinearParams(
+                            temp_embedder.template_pair_embedder.query_embedding_linear
+                        ),
+                        "template_embedding_iteration": tps_blocks_params,
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "output_linear": LinearParams(temp_embedder.linear_t),
+                },
+                "template_projection": LinearParams(
+                    temp_embedder.template_single_embedder.template_projector,
+                ),
+                "template_single_embedding": LinearParams(
+                    temp_embedder.template_single_embedder.template_single_embedder,
+                ),
+                "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(model.structure_module.linear_in),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+
+    return translations
+
+
+def import_jax_weights_(model, npz_path, version="model_1"):
+    data = np.load(npz_path)
+
+    translations = get_translation_dict(model, is_multimer=("multimer" in version))

     no_templ = [
         "model_3",
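Splitting `get_translation_dict` out of `import_jax_weights_` reduces version handling to a substring test: `is_multimer=("multimer" in version)`. Downstream, nested translation tables like these are typically flattened into scope-path keys before being matched against the `.npz` arrays; a generic sketch of that flattening (the `/` separator and the toy leaves are assumptions, not FastFold's exact code):

```python
def flatten(d, prefix=""):
    # Flatten a nested translation dict into {"scope/subscope/...": leaf} pairs.
    flat = {}
    for k, v in d.items():
        path = f"{prefix}/{k}" if prefix else k
        if isinstance(v, dict):
            flat.update(flatten(v, path))
        else:
            flat[path] = v
    return flat

translations = {
    "evoformer": {"preprocess_1d": "Param(...)"},
    "structure_module": {"fold_iteration": {"quat_rigid": "Param(...)"}},
}
print(sorted(flatten(translations)))
# ['evoformer/preprocess_1d', 'structure_module/fold_iteration/quat_rigid']
```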
fastfold/utils/inject_fastnn.py

@@ -266,7 +266,7 @@ def inject_extraMsaBlock(model):
 def inject_templatePairBlock(model):
     with torch.no_grad():
-        target_module = model.template_pair_stack.blocks
+        target_module = model.template_embedder.template_pair_stack.blocks
         fastfold_blocks = nn.ModuleList()
         for block_id, ori_block in enumerate(target_module):
             c_t = ori_block.c_t
@@ -294,7 +294,7 @@ def inject_templatePairBlock(model):
         fastfold_block.eval()
         fastfold_blocks.append(fastfold_block)

-    model.template_pair_stack.blocks = fastfold_blocks
+    model.template_embedder.template_pair_stack.blocks = fastfold_blocks


 def inject_fastnn(model):
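The two-line change simply follows `template_pair_stack` to its new home under `template_embedder`; the injection pattern itself — rebuild each block, then swap the whole `ModuleList` back onto the parent — is unchanged. A reduced sketch of that swap with a stand-in block class:

```python
import torch
import torch.nn as nn

class FastBlockStub(nn.Module):
    # placeholder for a FastFold-optimized block built from an original block
    def __init__(self, ori_block):
        super().__init__()
        self.inner = ori_block

def swap_blocks(parent: nn.Module, attr: str = "blocks"):
    with torch.no_grad():
        new_blocks = nn.ModuleList(FastBlockStub(b) for b in getattr(parent, attr))
        setattr(parent, attr, new_blocks)  # whole-list swap, as in the diff
```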
inference.py

@@ -159,7 +159,7 @@ def main(args):
     print("Generating features...")
     local_alignment_dir = os.path.join(alignment_dir, tag)
     if global_is_multimer:
-        print("multimer mode")
+        print("running in multimer mode...")
         feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
     else:
         if (args.use_precomputed_alignments is None):