OpenDAS / OpenFold · Commits · 39a6d0e6

Commit 39a6d0e6 authored Apr 09, 2023 by Christina Floristean

    Merging in main branch

Parents: d8ee9c5f, 84659c93
Changes: 101

Showing 20 changed files with 2298 additions and 649 deletions (+2298 −649)
openfold/model/evoformer.py                            +567  −133
openfold/model/heads.py                                  +9    −1
openfold/model/model.py                                +203   −99
openfold/model/msa.py                                  +104   −38
openfold/model/outer_product_mean.py                    +38    −8
openfold/model/pair_transition.py                        +6    −6
openfold/model/primitives.py                           +179   −61
openfold/model/structure_module.py                     +259  −107
openfold/model/template.py                             +360   −57
openfold/model/triangular_attention.py                  +37   −17
openfold/model/triangular_multiplicative_update.py     +331   −36
openfold/np/__init__.py                                  +0   −16
openfold/np/protein.py                                 +158   −25
openfold/np/relax/__init__.py                            +0   −16
openfold/np/relax/amber_minimize.py                     +25    −4
openfold/np/relax/cleanup.py                             +8    −2
openfold/np/relax/relax.py                               +3    −0
openfold/np/relax/utils.py                               +8    −2
openfold/np/residue_constants.py                         +3    −3
openfold/utils/__init__.py                               +0   −18
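Many of the hunks below replace plain `z = z + ...` updates with calls to an `add` helper imported from openfold.utils.tensor_utils, controlled by an `inplace` switch. That helper is not itself part of this diff; the sketch below is a minimal reading of what it does, modeled on the local `add` closure that the old ExtraMSABlock code (removed further down) used for the same purpose:

import torch

def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    # The first operation in a checkpoint can't be in-place, but in-place
    # addition saves memory during inference, so the caller decides.
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2
    return m1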
openfold/model/evoformer.py

@@ -12,11 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
 import sys
 import torch
 import torch.nn as nn
-from typing import Tuple, Optional
+from typing import Tuple, Sequence, Optional
 from functools import partial

 from openfold.model.primitives import Linear, LayerNorm
@@ -29,6 +29,7 @@ from openfold.model.msa import (
 from openfold.model.outer_product_mean import OuterProductMean
 from openfold.model.pair_transition import PairTransition
 from openfold.model.triangular_attention import (
+    TriangleAttention,
     TriangleAttentionStartingNode,
     TriangleAttentionEndingNode,
 )
@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationIncoming,
 )
 from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
+from openfold.utils.tensor_utils import add


 class MSATransition(nn.Module):
@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
         self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

     def _transition(self, m, mask):
+        m = self.layer_norm(m)
         m = self.linear_1(m)
         m = self.relu(m)
         m = self.linear_2(m) * mask
@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
         mask = mask.unsqueeze(-1)

-        m = self.layer_norm(m)
-
         if chunk_size is not None:
             m = self._chunk(m, mask, chunk_size)
         else:
@@ -140,13 +141,13 @@ class PairStack(nn.Module):
             c_hidden_mul,
         )

-        self.tri_att_start = TriangleAttentionStartingNode(
+        self.tri_att_start = TriangleAttention(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
             inf=inf,
         )
-        self.tri_att_end = TriangleAttentionEndingNode(
+        self.tri_att_end = TriangleAttention(
             c_z,
             c_hidden_pair_att,
             no_heads_pair,
@@ -159,32 +160,109 @@ class PairStack(nn.Module):
         )

         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
         self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

     def forward(self,
-        z: torch.Tensor,
+        input_tensors: Sequence[torch.Tensor],
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+        _offload_inference: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
         pair_trans_mask = pair_mask if _mask_trans else None

-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
-        )
+        if (_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
+        m, z = input_tensors
+
+        tmu_update = self.tri_mul_out(
+            z,
+            mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _add_with_inplace=True,
+        )
-        z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
-        )
+        if (not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        tmu_update = self.tri_mul_in(
+            z,
+            mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _add_with_inplace=True,
+        )
-        z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
-        )
+        if (not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        z = add(z,
+            self.ps_dropout_row_layer(
+                self.tri_att_start(
+                    z,
+                    mask=pair_mask,
+                    chunk_size=_attn_chunk_size,
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                    inplace_safe=inplace_safe,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        z = z.transpose(-2, -3)
+        if (inplace_safe):
+            input_tensors[1] = z.contiguous()
+            z = input_tensors[1]
+
+        z = add(z,
+            self.ps_dropout_row_layer(
+                self.tri_att_end(
+                    z,
+                    mask=pair_mask.transpose(-1, -2),
+                    chunk_size=_attn_chunk_size,
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                    inplace_safe=inplace_safe,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        z = z.transpose(-2, -3)
+        if (inplace_safe):
+            input_tensors[1] = z.contiguous()
+            z = input_tensors[1]
+
+        z = add(z,
+            self.pair_transition(
+                z, mask=pair_trans_mask, chunk_size=chunk_size,
+            ),
+            inplace=inplace_safe,
+        )
-        return z
+
+        if (_offload_inference and inplace_safe):
+            # m: GPU, z: GPU
+            device = z.device
+            del m, z
+            assert (sys.getrefcount(input_tensors[0]) == 2)
+            assert (sys.getrefcount(input_tensors[1]) == 2)
+            input_tensors[0] = input_tensors[0].to(device)
+            input_tensors[1] = input_tensors[1].to(device)
+            m, z = input_tensors
+
+        return m, z


 class EvoformerBlock(nn.Module):
@@ -248,41 +326,134 @@ class EvoformerBlock(nn.Module):
         )

     def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+        m: Optional[torch.Tensor],
+        z: Optional[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
+        use_flash: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
+        _offload_inference: bool = False,
+        _offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
         msa_trans_mask = msa_mask if _mask_trans else None

+        if (_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
+        if (_offload_inference and inplace_safe):
+            input_tensors = _offloadable_inputs
+            del _offloadable_inputs
+        else:
+            input_tensors = [m, z]
+
+        m, z = input_tensors
+
         if self.opm_first:
-            z = z + self.outer_product_mean(
-                m, mask=msa_mask, chunk_size=chunk_size
-            )
+            if (_offload_inference and inplace_safe):
+                # m: GPU, z: CPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[1]) == 2)
+                input_tensors[1] = input_tensors[1].cpu()
+                m, z = input_tensors
+
+            opm = self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
+            )
+
+            if (_offload_inference and inplace_safe):
+                # m: GPU, z: GPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[0]) == 2)
+                input_tensors[1] = input_tensors[1].to(opm.device)
+                m, z = input_tensors
+
+            z = add(z, opm, inplace=inplace_safe)
+            del opm

-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
-        )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
-        m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
-        )
+        m = add(m,
+            self.msa_dropout_layer(
+                self.msa_att_row(
+                    m,
+                    z=z,
+                    mask=msa_mask,
+                    chunk_size=_attn_chunk_size,
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        if (_offload_inference and inplace_safe):
+            # m: GPU, z: CPU
+            del m, z
+            assert (sys.getrefcount(input_tensors[1]) == 2)
+            input_tensors[1] = input_tensors[1].cpu()
+            torch.cuda.empty_cache()
+            m, z = input_tensors
+
+        m = add(m,
+            self.msa_att_col(
+                m,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                use_flash=use_flash,
+            ),
+            inplace=inplace_safe,
+        )
+
+        m = add(m,
+            self.msa_transition(
+                m, mask=msa_trans_mask, chunk_size=chunk_size,
+            ),
+            inplace=inplace_safe,
+        )

         if not self.opm_first:
-            z = z + self.outer_product_mean(
-                m, mask=msa_mask, chunk_size=chunk_size
-            )
+            opm = self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
+            )
+
+            if (_offload_inference and inplace_safe):
+                # m: CPU, z: GPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[0]) == 2)
+                input_tensors[0] = input_tensors[0].cpu()
+                input_tensors[1] = input_tensors[1].to(opm.device)
+                m, z = input_tensors
+
+            z = add(z, opm, inplace=inplace_safe)
+            del opm
+
+        elif (_offload_inference and inplace_safe):
+            # m: CPU, z: GPU
+            del m, z
+            assert (sys.getrefcount(input_tensors[0]) == 2)
+            device = input_tensors[0].device
+            input_tensors[0] = input_tensors[0].cpu()
+            input_tensors[1] = input_tensors[1].to(device)
+            m, z = input_tensors
+
+        if (not inplace_safe):
+            input_tensors = [m, z]
+
+        del m, z

-        z = self.pair_stack(
-            z,
-            pair_mask=pair_mask,
-            chunk_size=chunk_size,
-        )
+        m, z = self.pair_stack(
+            input_tensors,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            inplace_safe=inplace_safe,
+            _offload_inference=_offload_inference,
+            _mask_trans=_mask_trans,
+            _attn_chunk_size=_attn_chunk_size
+        )

         return m, z
@@ -358,63 +529,140 @@ class ExtraMSABlock(nn.Module):
         )

     def forward(self,
-        m: torch.Tensor,
-        z: torch.Tensor,
+        m: Optional[torch.Tensor],
+        z: Optional[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _chunk_logits: Optional[int] = 1024,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        def add(m1, m2):
-            # The first operation in a checkpoint can't be in-place, but it's
-            # nice to have in-place addition during inference. Thus...
-            if (torch.is_grad_enabled()):
-                m1 = m1 + m2
-            else:
-                m1 += m2
-
-            return m1
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+        _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
+        _offload_inference: bool = False,
+        _offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if (_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
+        if (_offload_inference and inplace_safe):
+            input_tensors = _offloadable_inputs
+            del _offloadable_inputs
+        else:
+            input_tensors = [m, z]
+
+        m, z = input_tensors

         if self.opm_first:
-            z = z + self.outer_product_mean(
-                m, mask=msa_mask, chunk_size=chunk_size
-            )
+            if (_offload_inference and inplace_safe):
+                # m: GPU, z: CPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[1]) == 2)
+                input_tensors[1] = input_tensors[1].cpu()
+                m, z = input_tensors
+
+            opm = self.outer_product_mean(
+                m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
+            )
+
+            if (_offload_inference and inplace_safe):
+                # m: GPU, z: GPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[0]) == 2)
+                input_tensors[1] = input_tensors[1].to(opm.device)
+                m, z = input_tensors
+
+            z = add(z, opm, inplace=inplace_safe)
+            del opm

         m = add(m,
             self.msa_dropout_layer(
                 self.msa_att_row(
                     m.clone() if torch.is_grad_enabled() else m,
                     z=z.clone() if torch.is_grad_enabled() else z,
                     mask=msa_mask,
-                    chunk_size=chunk_size,
-                    use_memory_efficient_kernel=not _chunk_logits,
-                    _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
+                    chunk_size=_attn_chunk_size,
+                    use_lma=use_lma,
+                    use_memory_efficient_kernel=not use_lma and m.is_cuda,
                     _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
                 )
-            )
-        )
+            ),
+            inplace=inplace_safe,
+        )
+
+        if (not inplace_safe):
+            input_tensors = [m, z]
+
+        del m, z

-        def fn(m, z):
-            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
-            m = add(m, self.msa_transition(m, mask=msa_mask, chunk_size=chunk_size))
+        def fn(input_tensors):
+            m, z = input_tensors
+
+            if (_offload_inference and inplace_safe):
+                # m: GPU, z: CPU
+                del m, z
+                assert (sys.getrefcount(input_tensors[1]) == 2)
+                input_tensors[1] = input_tensors[1].cpu()
+                torch.cuda.empty_cache()
+                m, z = input_tensors
+
+            m = add(m,
+                self.msa_att_col(
+                    m,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                ),
+                inplace=inplace_safe,
+            )
+
+            m = add(m,
+                self.msa_transition(
+                    m, mask=msa_mask, chunk_size=chunk_size,
+                ),
+                inplace=inplace_safe,
+            )

             if not self.opm_first:
-                z = z + self.outer_product_mean(
-                    m, mask=msa_mask, chunk_size=chunk_size
-                )
+                opm = self.outer_product_mean(
+                    m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
+                )
+
+                if (_offload_inference and inplace_safe):
+                    # m: CPU, z: GPU
+                    del m, z
+                    assert (sys.getrefcount(input_tensors[0]) == 2)
+                    input_tensors[0] = input_tensors[0].cpu()
+                    input_tensors[1] = input_tensors[1].to(opm.device)
+                    m, z = input_tensors
+
+                z = add(z, opm, inplace=inplace_safe)
+                del opm

-            z = self.pair_stack(
-                z,
-                pair_mask=pair_mask,
-                chunk_size=chunk_size,
-            )
+            if (not inplace_safe):
+                input_tensors = [m, z]
+
+            del m, z
+
+            m, z = self.pair_stack(
+                input_tensors,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
+                _offload_inference=_offload_inference,
+                _mask_trans=_mask_trans,
+                _attn_chunk_size=_attn_chunk_size
+            )

             return m, z

         if (torch.is_grad_enabled() and self.ckpt):
             checkpoint_fn = get_checkpoint_fn()
-            m, z = checkpoint_fn(fn, m, z)
+            m, z = checkpoint_fn(fn, input_tensors)
         else:
-            m, z = fn(m, z)
+            m, z = fn(input_tensors)

         return m, z
@@ -446,6 +694,7 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
+        tune_chunk_size: bool = False,
         **kwargs,
     ):
         """
@@ -482,6 +731,8 @@ class EvoformerStack(nn.Module):
             clear_cache_between_blocks:
                 Whether to clear CUDA's GPU memory cache between blocks of the
                 stack. Slows down each block but can reduce fragmentation
+            tune_chunk_size:
+                Whether to dynamically tune the module's chunk size
         """
         super(EvoformerStack, self).__init__()
@@ -511,14 +762,114 @@ class EvoformerStack(nn.Module):
         self.linear = Linear(c_m, c_s)

+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if (tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
+    def _prep_blocks(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool,
+        use_flash: bool,
+        msa_mask: Optional[torch.Tensor],
+        pair_mask: Optional[torch.Tensor],
+        inplace_safe: bool,
+        _mask_trans: bool,
+    ):
+        blocks = [
+            partial(
+                b,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                use_flash=use_flash,
+                inplace_safe=inplace_safe,
+                _mask_trans=_mask_trans,
+            )
+            for b in self.blocks
+        ]
+
+        if (self.clear_cache_between_blocks):
+            def block_with_cache_clear(block, *args, **kwargs):
+                torch.cuda.empty_cache()
+                return block(*args, **kwargs)
+
+            blocks = [partial(block_with_cache_clear, b) for b in blocks]
+
+        if (chunk_size is not None and self.chunk_size_tuner is not None):
+            assert (not self.training)
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                # We don't want to write in-place during chunk tuning runs
+                args=(m.clone(), z.clone(),),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
+                ) for b in blocks
+            ]
+
+        return blocks
+
+    def _forward_offload(self,
+        input_tensors: Sequence[torch.Tensor],
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool = False,
+        use_flash: bool = False,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert (not (self.training or torch.is_grad_enabled()))
+        blocks = self._prep_blocks(
+            # We are very careful not to create references to these tensors in
+            # this function
+            m=input_tensors[0],
+            z=input_tensors[1],
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=True,
+            _mask_trans=_mask_trans,
+        )
+
+        for b in blocks:
+            m, z = b(
+                None,
+                None,
+                _offload_inference=True,
+                _offloadable_inputs=input_tensors,
+            )
+            input_tensors[0] = m
+            input_tensors[1] = z
+            del m, z
+
+        m, z = input_tensors
+
+        s = self.linear(m[..., 0, :, :])
+
+        return m, z, s
+
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
+        use_flash: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
             m:
@@ -529,6 +880,13 @@ class EvoformerStack(nn.Module):
                 [*, N_seq, N_res] MSA mask
             pair_mask:
                 [*, N_res, N_res] pair mask
+            chunk_size:
+                Inference-time subbatch size. Acts as a minimum if
+                self.tune_chunk_size is True
+            use_lma: Whether to use low-memory attention during inference
+            use_flash:
+                Whether to use FlashAttention where possible. Mutually
+                exclusive with use_lma.
         Returns:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding
@@ -536,33 +894,31 @@ class EvoformerStack(nn.Module):
                 [*, N_res, N_res, C_z] pair embedding
             s:
                 [*, N_res, C_s] single embedding (or None if extra MSA stack)
         """
-        blocks = [
-            partial(
-                b,
-                msa_mask=msa_mask,
-                pair_mask=pair_mask,
-                chunk_size=chunk_size,
-                _mask_trans=_mask_trans,
-            )
-            for b in self.blocks
-        ]
-
-        if (self.clear_cache_between_blocks):
-            def block_with_cache_clear(block, *args):
-                torch.cuda.empty_cache()
-                return block(*args)
-
-            blocks = [partial(block_with_cache_clear, b) for b in blocks]
+        blocks = self._prep_blocks(
+            m=m,
+            z=z,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _mask_trans=_mask_trans,
+        )
+
+        blocks_per_ckpt = self.blocks_per_ckpt
+        if (not torch.is_grad_enabled()):
+            blocks_per_ckpt = None

         m, z = checkpoint_blocks(
             blocks,
             args=(m, z),
-            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
+            blocks_per_ckpt=blocks_per_ckpt,
         )

         s = self.linear(m[..., 0, :, :])

         return m, z, s
@@ -570,7 +926,6 @@ class ExtraMSAStack(nn.Module):
     """
     Implements Algorithm 18.
     """
     def __init__(self,
         c_m: int,
         c_z: int,
@@ -589,14 +944,13 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
-        chunk_msa_attn: bool = False,
+        tune_chunk_size: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()

         self.ckpt = ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self.chunk_msa_attn = chunk_msa_attn
         self.blocks = nn.ModuleList()
         for _ in range(no_blocks):
             block = ExtraMSABlock(
@@ -614,16 +968,107 @@ class ExtraMSAStack(nn.Module):
                 opm_first=opm_first,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt if chunk_msa_attn else False,
+                ckpt=False,
             )
             self.blocks.append(block)

+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if (tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
+    def _prep_blocks(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool,
+        msa_mask: Optional[torch.Tensor],
+        pair_mask: Optional[torch.Tensor],
+        inplace_safe: bool,
+        _mask_trans: bool,
+    ):
+        blocks = [
+            partial(
+                b,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
+                _mask_trans=_mask_trans,
+            )
+            for b in self.blocks
+        ]
+
+        def clear_cache(b, *args, **kwargs):
+            torch.cuda.empty_cache()
+            return b(*args, **kwargs)
+
+        if (self.clear_cache_between_blocks):
+            blocks = [partial(clear_cache, b) for b in blocks]
+
+        if (chunk_size is not None and self.chunk_size_tuner is not None):
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                # Tensors cloned to avoid getting written to in-place
+                # A corollary is that chunk size tuning should be disabled for
+                # large N, when z gets really big
+                args=(m.clone(), z.clone(),),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
+                ) for b in blocks
+            ]
+
+        return blocks
+
+    def _forward_offload(self,
+        input_tensors: Sequence[torch.Tensor],
+        chunk_size: int,
+        use_lma: bool = False,
+        msa_mask: Optional[torch.Tensor] = None,
+        pair_mask: Optional[torch.Tensor] = None,
+        _mask_trans: bool = True,
+    ) -> torch.Tensor:
+        assert (not (self.training or torch.is_grad_enabled()))
+        blocks = self._prep_blocks(
+            # We are very careful not to create references to these tensors in
+            # this function
+            m=input_tensors[0],
+            z=input_tensors[1],
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=True,
+            _mask_trans=_mask_trans,
+        )
+
+        for b in blocks:
+            m, z = b(
+                None,
+                None,
+                _offload_inference=True,
+                _offloadable_inputs=input_tensors,
+            )
+            input_tensors[0] = m
+            input_tensors[1] = z
+            del m, z
+
+        return input_tensors[1]
+
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: Optional[torch.Tensor],
-        pair_mask: Optional[torch.Tensor],
         chunk_size: int,
+        msa_mask: Optional[torch.Tensor] = None,
+        pair_mask: Optional[torch.Tensor] = None,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
     ) -> torch.Tensor:
         """
@@ -632,6 +1077,8 @@ class ExtraMSAStack(nn.Module):
                 [*, N_extra, N_res, C_m] extra MSA embedding
             z:
                 [*, N_res, N_res, C_z] pair embedding
+            chunk_size: Inference-time subbatch size for Evoformer modules
+            use_lma: Whether to use low-memory attention during inference
             msa_mask:
                 Optional [*, N_extra, N_res] MSA mask
             pair_mask:
@@ -639,35 +1086,22 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        if (not self.chunk_msa_attn):
-            checkpoint_fn = get_checkpoint_fn()
-            blocks = [
-                partial(
-                    b,
-                    msa_mask=msa_mask,
-                    pair_mask=pair_mask,
-                    chunk_size=chunk_size,
-                    _chunk_logits=None,
-                )
-                for b in self.blocks
-            ]
-
-            def clear_cache(b, *args):
-                torch.cuda.empty_cache()
-                return b(*args)
-
-            if (self.clear_cache_between_blocks):
-                blocks = [partial(clear_cache, b) for b in blocks]
-
-            for b in blocks:
-                if (self.ckpt and torch.is_grad_enabled()):
-                    m, z = checkpoint_fn(b, *(m, z))
-                else:
-                    m, z = b(m, z)
-        else:
-            for b in self.blocks:
-                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+        checkpoint_fn = get_checkpoint_fn()
+        blocks = self._prep_blocks(
+            m=m,
+            z=z,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _mask_trans=_mask_trans,
+        )
+
+        if (self.clear_cache_between_blocks):
+            torch.cuda.empty_cache()
+
+        for b in blocks:
+            if (self.ckpt and torch.is_grad_enabled()):
+                m, z = checkpoint_fn(b, m, z)
+            else:
+                m, z = b(m, z)

         return z
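The repeated `assert (sys.getrefcount(input_tensors[i]) == 2)` checks and `.cpu()` / `.to(device)` moves in the evoformer hunks above all follow one pattern: the large MSA and pair tensors live only inside the `input_tensors` list, every local alias is deleted before a move, and the list element is then free to change device. A standalone sketch of that pattern, with a hypothetical `heavy_update` standing in for the real sub-modules and assuming the caller keeps no other references of its own:

import sys
import torch

def offload_step(input_tensors, heavy_update):
    # input_tensors is a plain list [m, z]; the list is meant to be the
    # only owner of both tensors.
    m, z = input_tensors
    del m, z  # drop the local aliases before moving anything

    # One reference from the list plus getrefcount's own temporary == 2,
    # so nothing else is holding on to z and it can be parked on the CPU.
    assert sys.getrefcount(input_tensors[1]) == 2
    input_tensors[1] = input_tensors[1].cpu()

    input_tensors[0] = heavy_update(input_tensors[0])

    # Bring z back to the same device as the updated m.
    input_tensors[1] = input_tensors[1].to(input_tensors[0].device)
    return input_tensors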
openfold/model/heads.py

@@ -22,6 +22,7 @@ from openfold.utils.loss import (
     compute_tm,
     compute_predicted_aligned_error,
 )
+from openfold.utils.precision_utils import is_fp16_enabled


 class AuxiliaryHeads(nn.Module):
@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
         self.linear = Linear(self.c_z, self.no_bins, init="final")

-    def forward(self, z):  # [*, N, N, C_z]
+    def _forward(self, z):  # [*, N, N, C_z]
         """
         Args:
             z:
@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
         logits = self.linear(z)
         logits = logits + logits.transpose(-2, -3)
         return logits

+    def forward(self, z):
+        if (is_fp16_enabled()):
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)


 class TMScoreHead(nn.Module):
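DistogramHead above (and OuterProductMean further down) is split into a private `_forward` and a public `forward` that disables autocast and upcasts to float32 when half-precision autocasting is active. A minimal sketch of that wrapper shape; `compute` is a stand-in for the real `_forward`, and `torch.is_autocast_enabled()` is used here in place of the project's `is_fp16_enabled()` helper, which additionally inspects the autocast dtype:

import torch

def fp16_safe(compute, x: torch.Tensor) -> torch.Tensor:
    # Run this op in full precision even when fp16 autocast is active:
    # disable autocast locally and feed an explicit float32 copy.
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return compute(x.float())
    return compute(x)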
openfold/model/model.py

@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import weakref

 import torch
 import torch.nn as nn
@@ -34,12 +35,26 @@ from openfold.model.embedders import (
 )
 from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
 from openfold.model.heads import AuxiliaryHeads
-import openfold.np.residue_constants as residue_constants
 from openfold.model.structure_module import StructureModule
+from openfold.model.template import (
+    TemplatePairStack,
+    TemplatePointwiseAttention,
+    embed_templates_average,
+    embed_templates_offload,
+)
+import openfold.np.residue_constants as residue_constants
 from openfold.utils.feats import (
     pseudo_beta_fn,
     build_extra_msa_feat,
     build_template_angle_feat,
     build_template_pair_feat,
     atom14_to_atom37,
 )
+from openfold.utils.loss import (
+    compute_plddt,
+)
 from openfold.utils.tensor_utils import (
+    add,
     dict_multimap,
     tensor_tree_map,
 )
@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
         super(AlphaFold, self).__init__()

         self.globals = config.globals
-        config = config.model
-        template_config = config.template
-        extra_msa_config = config.extra_msa
+        self.config = config.model
+        self.template_config = self.config.template
+        self.extra_msa_config = self.config.extra_msa

         # Main trunk + structure module
         if (self.globals.is_multimer):
             self.input_embedder = InputEmbedderMultimer(
-                **config["input_embedder"],
+                **self.config["input_embedder"],
             )
         else:
             self.input_embedder = InputEmbedder(
-                **config["input_embedder"],
+                **self.config["input_embedder"],
             )

         self.recycling_embedder = RecyclingEmbedder(
-            **config["recycling_embedder"],
+            **self.config["recycling_embedder"],
         )

-        if (self.globals.is_multimer):
-            self.template_embedder = TemplateEmbedderMultimer(
-                template_config,
-            )
-        else:
-            self.template_embedder = TemplateEmbedder(
-                template_config,
-            )
-
-        self.extra_msa_embedder = ExtraMSAEmbedder(
-            **extra_msa_config["extra_msa_embedder"],
-        )
-        self.extra_msa_stack = ExtraMSAStack(
-            **extra_msa_config["extra_msa_stack"],
-        )
+        if (self.template_config.enabled):
+            if (self.globals.is_multimer):
+                self.template_embedder = TemplateEmbedderMultimer(
+                    self.template_config,
+                )
+            else:
+                self.template_embedder = TemplateEmbedder(
+                    self.template_config,
+                )
+
+        if (self.extra_msa_config.enabled):
+            self.extra_msa_embedder = ExtraMSAEmbedder(
+                **self.extra_msa_config["extra_msa_embedder"],
+            )
+            self.extra_msa_stack = ExtraMSAStack(
+                **self.extra_msa_config["extra_msa_stack"],
+            )
+
         self.evoformer = EvoformerStack(
-            **config["evoformer_stack"],
+            **self.config["evoformer_stack"],
         )
         self.structure_module = StructureModule(
             is_multimer=self.globals.is_multimer,
-            **config["structure_module"],
+            **self.config["structure_module"],
         )
         self.aux_heads = AuxiliaryHeads(
-            config["heads"],
+            self.config["heads"],
         )
-        self.config = config

-    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
+    def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
+        if (self.globals.is_multimer):
+            asym_id = feats["asym_id"]
+            multichain_mask_2d = (
+                asym_id[..., None] == asym_id[..., None, :]
+            )
+            template_embeds = self.template_embedder(
+                batch,
+                z,
+                pair_mask.to(dtype=z.dtype),
+                templ_dim,
+                chunk_size=self.globals.chunk_size,
+                multichain_mask_2d=multichain_mask_2d,
+                use_lma=self.globals.use_lma,
+                inplace_safe=inplace_safe
+            )
+            feats["template_torsion_angles_mask"] = (
+                template_embeds["template_mask"]
+            )
+        else:
+            if (self.template_config.offload_templates):
+                return embed_templates_offload(
+                    self, batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
+                )
+            elif (self.template_config.average_templates):
+                return embed_templates_average(
+                    self, batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
+                )
+
+            template_embeds = self.template_embedder(
+                batch,
+                z,
+                pair_mask.to(dtype=z.dtype),
+                templ_dim,
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                inplace_safe=inplace_safe
+            )
+
+        return template_embeds
+
+    def iteration(self, feats, prevs, _recycle=True):
         # Primary output dictionary
         outputs = {}
@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
         n = feats["target_feat"].shape[-2]
         n_seq = feats["msa_feat"].shape[-3]
         device = feats["target_feat"].device

+        # Controls whether the model uses in-place operations throughout
+        # The dual condition accounts for activation checkpoints
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
         # Prep some features
         seq_mask = feats["seq_mask"]
         pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
         msa_mask = feats["msa_mask"]

-        ## Initialize the MSA and pair representations
+        # Initialize the MSA and pair representations
         if (self.globals.is_multimer):
             # m: [*, S_c, N, C_m]
             # z: [*, N, N, C_z]
             m, z = self.input_embedder(feats)
         else:
             # m: [*, S_c, N, C_m]
             # z: [*, N, N, C_z]
             m, z = self.input_embedder(
                 feats["target_feat"],
                 feats["residue_index"],
                 feats["msa_feat"],
+                inplace_safe=inplace_safe,
             )

-        # Initialize the recycling embeddings, if needs be
+        # Unpack the recycling embeddings. Removing them from the list allows
+        # them to be freed further down in this function, saving memory
+        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
+
+        # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
             # [*, N, C_m]
             m_1_prev = m.new_zeros(
@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
                 feats["aatype"], x_prev, None
             ).to(dtype=z.dtype)

+        # The recycling embedder is memory-intensive, so we offload first
+        if (self.globals.offload_inference and inplace_safe):
+            m = m.cpu()
+            z = z.cpu()
+
         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
         m_1_prev_emb, z_prev_emb = self.recycling_embedder(
             m_1_prev,
             z_prev,
             x_prev,
+            inplace_safe=inplace_safe,
         )

-        # If the number of recycling iterations is 0, skip recycling
-        # altogether. We zero them this way instead of computing them
-        # conditionally to avoid leaving parameters unused, which has annoying
-        # implications for DDP training.
-        # EDIT: This has since been removed from the official codebase (2cd61a)
-        # if(not _recycle):
-        #     m_1_prev_emb *= 0
-        #     z_prev_emb *= 0
+        if (self.globals.offload_inference and inplace_safe):
+            m = m.to(m_1_prev_emb.device)
+            z = z.to(z_prev.device)

         # [*, S_c, N, C_m]
         m[..., 0, :, :] += m_1_prev_emb

         # [*, N, N, C_z]
-        z += z_prev_emb
+        z = add(z, z_prev_emb, inplace=inplace_safe)

-        # Possibly prevents memory fragmentation
+        # Deletions like these become significant for inference with large N,
+        # where they free unused tensors and remove references to others such
+        # that they can be offloaded later
         del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
             template_feats = {
                 k: v for k, v in feats.items() if k.startswith("template_")
             }
-            if (self.globals.is_multimer):
-                asym_id = feats["asym_id"]
-                multichain_mask_2d = (
-                    asym_id[..., None] == asym_id[..., None, :]
-                )
-                template_embeds = self.template_embedder(
-                    template_feats,
-                    z,
-                    pair_mask.to(dtype=z.dtype),
-                    no_batch_dims,
-                    chunk_size=self.globals.chunk_size,
-                    multichain_mask_2d=multichain_mask_2d,
-                )
-                feats["template_torsion_angles_mask"] = (
-                    template_embeds["template_mask"]
-                )
-            else:
-                template_embeds = self.template_embedder(
-                    template_feats,
-                    z,
-                    pair_mask.to(dtype=z.dtype),
-                    no_batch_dims,
-                    self.globals.chunk_size
-                )
+            template_embeds = self.embed_templates(
+                template_feats,
+                feats,
+                z,
+                pair_mask.to(dtype=z.dtype),
+                no_batch_dims,
+                inplace_safe=inplace_safe,
+            )

             # [*, N, N, C_z]
-            z = z + template_embeds["template_pair_embedding"]
+            z = add(z,
+                template_embeds.pop("template_pair_embedding"),
+                inplace_safe,
+            )

             if (
-                self.config.template.embed_angles or
-                (self.globals.is_multimer and self.config.template.enabled)
+                "template_single_embedding" in template_embeds
             ):
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
             # [*, S_e, N, C_e]
             extra_msa_feat = extra_msa_fn(feats)
-            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
+            a = self.extra_msa_embedder(extra_msa_feat)

-            # [*, N, N, C_z]
-            z = self.extra_msa_stack(
-                extra_msa_feat,
-                z,
-                msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
-                chunk_size=self.globals.chunk_size,
-                pair_mask=pair_mask.to(dtype=m.dtype),
-                _mask_trans=self.config._mask_trans,
-            )
+            if (self.globals.offload_inference):
+                # To allow the extra MSA stack (and later the evoformer) to
+                # offload its inputs, we remove all references to them here
+                input_tensors = [a, z]
+                del a, z
+
+                # [*, N, N, C_z]
+                z = self.extra_msa_stack._forward_offload(
+                    input_tensors,
+                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
+                    chunk_size=self.globals.chunk_size,
+                    use_lma=self.globals.use_lma,
+                    pair_mask=pair_mask.to(dtype=m.dtype),
+                    _mask_trans=self.config._mask_trans,
+                )
+
+                del input_tensors
+            else:
+                # [*, N, N, C_z]
+                z = self.extra_msa_stack(
+                    a, z,
+                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
+                    chunk_size=self.globals.chunk_size,
+                    use_lma=self.globals.use_lma,
+                    pair_mask=pair_mask.to(dtype=m.dtype),
+                    inplace_safe=inplace_safe,
+                    _mask_trans=self.config._mask_trans,
+                )

         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
         # z: [*, N, N, C_z]
         # s: [*, N, C_s]
-        m, z, s = self.evoformer(
-            m,
-            z,
-            msa_mask=msa_mask.to(dtype=m.dtype),
-            pair_mask=pair_mask.to(dtype=z.dtype),
-            chunk_size=self.globals.chunk_size,
-            _mask_trans=self.config._mask_trans,
-        )
+        if (self.globals.offload_inference):
+            input_tensors = [m, z]
+            del m, z
+            m, z, s = self.evoformer._forward_offload(
+                input_tensors,
+                msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
+                pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                _mask_trans=self.config._mask_trans,
+            )
+            del input_tensors
+        else:
+            m, z, s = self.evoformer(
+                m,
+                z,
+                msa_mask=msa_mask.to(dtype=m.dtype),
+                pair_mask=pair_mask.to(dtype=z.dtype),
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                use_flash=self.globals.use_flash,
+                inplace_safe=inplace_safe,
+                _mask_trans=self.config._mask_trans,
+            )

         outputs["msa"] = m[..., :n_seq, :, :]
         outputs["pair"] = z
         outputs["single"] = s

+        del z
+
         # Predict 3D structure
         outputs["sm"] = self.structure_module(
-            s,
-            z,
+            outputs,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
+            inplace_safe=inplace_safe,
+            _offload_inference=self.globals.offload_inference,
         )
         outputs["final_atom_positions"] = atom14_to_atom37(
             outputs["sm"]["positions"][-1], feats
@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
         m_1_prev = m[..., 0, :, :]

         # [*, N, N, C_z]
-        z_prev = z
+        z_prev = outputs["pair"]

         # [*, N, 3]
         x_prev = outputs["final_atom_positions"]
@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
+        prevs = [m_1_prev, z_prev, x_prev]

-        # Disable activation checkpointing for the first few recycling iters
         is_grad_enabled = torch.is_grad_enabled()
-        self._disable_activation_checkpointing()

         # Main recycling loop
         num_iters = batch["aatype"].shape[-1]
         for cycle_no in range(num_iters):
             # Select the features for the current recycling cycle
             fetch_cur_batch = lambda t: t[..., cycle_no]
             feats = tensor_tree_map(fetch_cur_batch, batch)
@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
             is_final_iter = cycle_no == (num_iters - 1)

             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                 if is_final_iter:
-                    self._enable_activation_checkpointing()
                     # Sidestep AMP bug (PyTorch issue #65766)
                     if torch.is_autocast_enabled():
                         torch.clear_autocast_cache()
@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
-                    m_1_prev,
-                    z_prev,
-                    x_prev,
+                    prevs,
                     _recycle=(num_iters > 1)
                 )

+                if (not is_final_iter):
+                    del outputs
+                    prevs = [m_1_prev, z_prev, x_prev]
+                    del m_1_prev, z_prev, x_prev
+
         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
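The reworked run loop above hands the recycling tensors to iteration() as a single list (`prevs`) and pops them inside, so the previous-cycle tensors lose their last outside reference as soon as they have been consumed. A toy version of that hand-off, independent of the model (the values are placeholders):

def iteration(feats, prevs):
    # Popping empties the caller's list, so nothing outside keeps the
    # previous-cycle values alive once they have been read here.
    m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
    # ... consume the previous-cycle values, then return the new ones ...
    return {"cycle": feats["cycle"]}, m_1_prev, z_prev, x_prev

prevs = [None, None, None]
for cycle_no in range(3):
    outputs, m_1_prev, z_prev, x_prev = iteration({"cycle": cycle_no}, prevs)
    prevs = [m_1_prev, z_prev, x_prev]
    del m_1_prev, z_prev, x_prev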
openfold/model/msa.py

@@ -26,8 +26,8 @@ from openfold.model.primitives import (
     _attention_chunked_trainable,
 )
 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.chunk_utils import chunk_layer
 from openfold.utils.tensor_utils import (
-    chunk_layer,
     permute_final_dims,
     flatten_final_dims,
 )
@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        biases: List[torch.Tensor],
-        use_memory_efficient_kernel: bool,
+        biases: Optional[List[torch.Tensor]],
         chunk_size: int,
+        use_memory_efficient_kernel: bool,
+        use_lma: bool,
+        use_flash: bool,
+        flash_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        mha = partial(
-            self.mha,
-            use_memory_efficient_kernel=use_memory_efficient_kernel,
-        )
+        def fn(m, biases, flash_mask):
+            m = self.layer_norm_m(m)
+            return self.mha(
+                q_x=m,
+                kv_x=m,
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=flash_mask,
+            )
+
+        inputs = {"m": m}
+        if (biases is not None):
+            inputs["biases"] = biases
+        else:
+            fn = partial(fn, biases=None)
+        if (use_flash and flash_mask is not None):
+            inputs["flash_mask"] = flash_mask
+        else:
+            fn = partial(fn, flash_mask=None)
+
         return chunk_layer(
-            mha,
-            {
-                "q_x": m,
-                "kv_x": m,
-                "biases": biases,
-            },
+            fn,
+            inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2])
         )
@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
     def _prep_inputs(self,
         m: torch.Tensor,
         z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # [*, N_seq, N_res, C_m]
-        m = self.layer_norm_m(m)
+        mask: Optional[torch.Tensor],
+        inplace_safe: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         n_seq, n_res = m.shape[-3:-1]
         if mask is None:
             # [*, N_seq, N_res]
@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
             self.layer_norm_z is not None and  # benefit of
             self.linear_z is not None          # TorchScript
         ):
-            # [*, N_res, N_res, C_z]
-            z = self.layer_norm_z(z)
-
-            # [*, N_res, N_res, no_heads]
-            z = self.linear_z(z)
+            chunks = []
+            for i in range(0, z.shape[-3], 256):
+                z_chunk = z[..., i: i + 256, :, :]
+
+                # [*, N_res, N_res, C_z]
+                z_chunk = self.layer_norm_z(z_chunk)
+
+                # [*, N_res, N_res, no_heads]
+                z_chunk = self.linear_z(z_chunk)
+
+                chunks.append(z_chunk)
+
+            z = torch.cat(chunks, dim=-3)

             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor],
         chunk_logits: int,
         checkpoint: bool,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         MSA attention with training-time chunking of the softmax computation.
@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
         MSA_DIM = -4

         def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
+            m = self.layer_norm_m(m)
             q, k, v = self.mha._prep_qkv(m, m)
             return m, q, k, v, mask_bias, z
@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        use_flash: bool = False,
+        inplace_safe: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
         if (_chunk_logits is not None):
             return self._chunked_msa_attn(
                 m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
+                chunk_logits=_chunk_logits,
+                checkpoint=_checkpoint_chunks,
+                inplace_safe=inplace_safe,
             )

-        m, mask_bias, z = self._prep_inputs(m, z, mask)
-
-        biases = [mask_bias]
-        if (z is not None):
-            biases.append(z)
+        if (use_flash):
+            assert z is None
+            biases = None
+        else:
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
+            biases = [mask_bias]
+            if (z is not None):
+                biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
+            m = self._chunk(
+                m,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
+            )
         else:
+            m = self.layer_norm_m(m)
             m = self.mha(
                 q_x=m,
                 kv_x=m,
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )

         return m
@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
-        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        use_flash: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(
+            m,
+            mask=mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+        )

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
             "mask": mask,
         }
+
+        def fn(m, mask):
+            m = self.layer_norm_m(m)
+            return self.global_attention(m, mask, use_lma=use_lma)
+
         return chunk_layer(
-            self.global_attention,
+            fn,
             mha_input,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),
@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
         mask = mask.transpose(-1, -2)

         # [*, N_res, N_seq, C_in]
-        m = self.layer_norm_m(m)
+        # m = self.layer_norm_m(m)

         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.layer_norm_m(m)
+            m = self.global_attention(m=m, mask=mask, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
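Several call sites above now hand a locally defined `fn` to `chunk_layer` (imported from openfold.utils.chunk_utils after the module move) rather than passing the attention module directly. chunk_layer's job is to run a function over slices of the batch-like dimensions and stitch the results back together; the sketch below shows that idea for a single tensor and a single leading dimension, and is deliberately much simpler than the library function's real signature:

import torch

def chunk_layer_1d(fn, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Apply fn to slabs of the leading dimension and concatenate the results,
    # trading peak memory for a little extra launch overhead.
    out = []
    for start in range(0, x.shape[0], chunk_size):
        out.append(fn(x[start:start + chunk_size]))
    return torch.cat(out, dim=0)

# Example: chunk_layer_1d(torch.nn.functional.relu, torch.randn(10, 4), 4)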
openfold/model/outer_product_mean.py

@@ -20,7 +20,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled


 class OuterProductMean(nn.Module):
@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
                 no_batch_dims=1,
             )
             out.append(outer)
-        outer = torch.stack(out, dim=0)
+
+        # For some cursed reason making this distinction saves memory
+        if (len(out) == 1):
+            outer = out[0].unsqueeze(0)
+        else:
+            outer = torch.stack(out, dim=0)
+
         outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

         return outer

-    def forward(self,
+    def _forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
             mask = m.new_ones(m.shape[:-1])

         # [*, N_seq, N_res, C_m]
-        m = self.layer_norm(m)
+        ln = self.layer_norm(m)

         # [*, N_seq, N_res, C]
         mask = mask.unsqueeze(-1)
-        a = self.linear_1(m) * mask
-        b = self.linear_2(m) * mask
+        a = self.linear_1(ln)
+        a = a * mask
+
+        b = self.linear_2(ln)
+        b = b * mask
+
+        del ln

         a = a.transpose(-2, -3)
         b = b.transpose(-2, -3)
@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
         # [*, N_res, N_res, 1]
         norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
+        norm = norm + self.eps

         # [*, N_res, N_res, C_z]
-        outer = outer / (self.eps + norm)
+        if (inplace_safe):
+            outer /= norm
+        else:
+            outer = outer / norm

         return outer

+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        if (is_fp16_enabled()):
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(m.float(), mask, chunk_size, inplace_safe)
+        else:
+            return self._forward(m, mask, chunk_size, inplace_safe)
openfold/model/pair_transition.py

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.chunk_utils import chunk_layer


 class PairTransition(nn.Module):
@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
         self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

     def _transition(self, z, mask):
+        # [*, N_res, N_res, C_z]
+        z = self.layer_norm(z)
+
         # [*, N_res, N_res, C_hidden]
         z = self.linear_1(z)
         z = self.relu(z)

         # [*, N_res, N_res, C_z]
-        z = self.linear_2(z) * mask
+        z = self.linear_2(z)
+        z = z * mask

         return z
@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
             no_batch_dims=len(z.shape[:-2]),
         )

     def forward(self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
         # [*, N_res, N_res, 1]
         mask = mask.unsqueeze(-1)

-        # [*, N_res, N_res, C_z]
-        z = self.layer_norm(z)
-
         if chunk_size is not None:
             z = self._chunk(z, mask, chunk_size)
         else:
openfold/model/primitives.py

@@ -13,24 +13,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
 import numpy as np

-import deepspeed
+deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+if (deepspeed_is_installed):
+    import deepspeed
+
+fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+if (fa_is_installed):
+    from flash_attn.bert_padding import unpad_input, pad_input
+    from flash_attn.flash_attention import FlashAttention
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
-    _chunk_slice,
 )

+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+

 def _prod(nums):
     out = 1
     for n in nums:
@@ -145,26 +160,26 @@ class Linear(nn.Linear):
             with torch.no_grad():
                 self.bias.fill_(0)

-        if init_fn is not None:
-            init_fn(self.weight, self.bias)
-        else:
-            if init == "default":
-                lecun_normal_init_(self.weight)
-            elif init == "relu":
-                he_normal_init_(self.weight)
-            elif init == "glorot":
-                glorot_uniform_init_(self.weight)
-            elif init == "gating":
-                gating_init_(self.weight)
-                if bias:
-                    with torch.no_grad():
-                        self.bias.fill_(1.0)
-            elif init == "normal":
-                normal_init_(self.weight)
-            elif init == "final":
-                final_init_(self.weight)
-            else:
-                raise ValueError("Invalid init string.")
+        with torch.no_grad():
+            if init_fn is not None:
+                init_fn(self.weight, self.bias)
+            else:
+                if init == "default":
+                    lecun_normal_init_(self.weight)
+                elif init == "relu":
+                    he_normal_init_(self.weight)
+                elif init == "glorot":
+                    glorot_uniform_init_(self.weight)
+                elif init == "gating":
+                    gating_init_(self.weight)
+                    if bias:
+                        self.bias.fill_(1.0)
+                elif init == "normal":
+                    normal_init_(self.weight)
+                elif init == "final":
+                    final_init_(self.weight)
+                else:
+                    raise ValueError("Invalid init string.")


 class LayerNorm(nn.Module):
@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         d = x.dtype
-        if (d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        deepspeed_is_initialized = (
+            deepspeed_is_installed and
+            deepspeed.utils.is_initialized()
+        )
+        if (d is torch.bfloat16 and not deepspeed_is_initialized):
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
         type bfloat16
     """
     d = t.dtype
-    if (d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+    deepspeed_is_initialized = (
+        deepspeed_is_installed and
+        deepspeed.utils.is_initialized()
+    )
+    if (d is torch.bfloat16 and not deepspeed_is_initialized):
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:
@@ -403,8 +426,10 @@ class Attention(nn.Module):
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
-        q_chunk_size: Optional[int] = None,
-        kv_chunk_size: Optional[int] = None,
+        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
+        use_flash: bool = False,
+        flash_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -423,29 +448,41 @@ class Attention(nn.Module):
                 Whether to use low-memory attention (Staats & Rabe 2021). If
                 none of the "use_<...>" flags are True, a stock PyTorch
                 implementation is used instead
-            q_chunk_size:
+            lma_q_chunk_size:
                 Query chunk size (for LMA)
-            kv_chunk_size:
+            lma_kv_chunk_size:
                 Key/Value chunk size (for LMA)
         Returns
             [*, Q, C_q] attention update
         """
-        if (biases is None):
-            biases = []
-        if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
+        if (use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
             raise ValueError(
-                "If use_lma is specified, q_chunk_size and kv_chunk_size must "
-                "be provided"
+                "If use_lma is specified, lma_q_chunk_size and "
+                "lma_kv_chunk_size must be provided"
             )
-        if (use_memory_efficient_kernel and use_lma):
+
+        if (use_flash and biases is not None):
             raise ValueError(
-                "Choose one of use_memory_efficient_kernel and use_lma"
+                "use_flash is incompatible with the bias option. For masking, "
+                "use flash_mask instead"
             )
+
+        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
+        if (sum(attn_options) > 1):
+            raise ValueError(
+                "Choose at most one alternative attention algorithm"
+            )
+
+        if (biases is None):
+            biases = []

         # [*, H, Q/K, C_hidden]
         q, k, v = self._prep_qkv(q_x, kv_x)

         # [*, Q, H, C_hidden]
+        if is_fp16_enabled():
+            use_memory_efficient_kernel = False
+
         if (use_memory_efficient_kernel):
             if (len(biases) > 2):
                 raise ValueError(
@@ -459,7 +496,10 @@ class Attention(nn.Module):
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
-            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
+            o = o.transpose(-2, -3)
+        elif (use_flash):
+            o = _flash_attn(q, k, v, flash_mask)
         else:
             o = _attention(q, k, v, biases)
             o = o.transpose(-2, -3)
@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
         self.sigmoid = nn.Sigmoid()

-    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        use_lma: bool = False,
+    ) -> torch.Tensor:
         # [*, N_res, C_in]
         q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
             torch.sum(mask, dim=-1)[..., None] + self.eps
@@ -511,20 +555,30 @@ class GlobalAttention(nn.Module):
         k = self.linear_k(m)
         v = self.linear_v(m)

-        # [*, N_res, H, N_seq]
-        a = torch.matmul(
-            q,
-            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
-        )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
-        a = softmax_no_cast(a)
-
-        # [*, N_res, H, C_hidden]
-        o = torch.matmul(
-            a,
-            v,
-        )
+        if (not use_lma):
+            # [*, N_res, H, N_seq]
+            a = torch.matmul(
+                q,
+                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
+            )
+            a += bias
+            a = softmax_no_cast(a)
+
+            # [*, N_res, H, C_hidden]
+            o = torch.matmul(
+                a,
+                v,
+            )
+        else:
+            o = _lma(
+                q, k, v, [bias],
+                DEFAULT_LMA_Q_CHUNK_SIZE,
+                DEFAULT_LMA_KV_CHUNK_SIZE,
+            )

         # [*, N_res, N_seq, C_hidden]
         g = self.sigmoid(self.linear_g(m))
@@ -552,12 +606,12 @@ def _lma(
    q_chunk_size: int,
    kv_chunk_size: int,
):
-   no_q, no_kv = q.shape[-3], k.shape[-3]
+   no_q, no_kv = q.shape[-2], k.shape[-2]

-   # [*, Q, H, C_hidden]
+   # [*, H, Q, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
-       q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+       q_chunk = q[..., q_s: q_s + q_chunk_size, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]
...
...
@@ -566,24 +620,22 @@ def _lma(
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
-           k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
-           v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+           k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+           v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
-               "...qhd,...khd->...hqk", q_chunk, k_chunk,
+               "...hqd,...hkd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

-           a = a.transpose(-2, -3)

            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
-           exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+           exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
...
...
@@ -595,14 +647,80 @@ def _lma(
        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
-       chunk_values *= max_diffs.unsqueeze(-1)
-       chunk_weights *= max_diffs
+       chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+       chunk_weights = chunk_weights * max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

-       o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+       o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

    return o
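The hunks above re-chunk `_lma` over the second-to-last dimension; the underlying trick is a streaming softmax that keeps per-chunk maxima, exponentiated weights, and weighted values, then rescales when combining chunks. A minimal self-contained sketch of that accumulation (plain PyTorch, not the OpenFold function itself) that checks the chunked result against ordinary attention:

import torch

def chunked_attention(q, k, v, kv_chunk: int = 16):
    # q: [H, Q, D], k/v: [H, K, D]
    maxes, weights, values = [], [], []
    for s in range(0, k.shape[-2], kv_chunk):
        k_c, v_c = k[:, s:s + kv_chunk], v[:, s:s + kv_chunk]
        a = torch.einsum("hqd,hkd->hqk", q, k_c)              # chunk logits
        m = a.max(dim=-1, keepdim=True)[0]                     # per-chunk max
        e = torch.exp(a - m)                                   # stabilized exponentials
        maxes.append(m.squeeze(-1))
        weights.append(e.sum(dim=-1))                          # softmax denominators
        values.append(torch.einsum("hvf,hqv->hqf", v_c, e))    # unnormalized numerators
    chunk_max = torch.stack(maxes, dim=0)                      # [C, H, Q]
    chunk_weights = torch.stack(weights, dim=0)
    chunk_values = torch.stack(values, dim=0)                  # [C, H, Q, D]
    global_max = chunk_max.max(dim=0, keepdim=True)[0]
    scale = torch.exp(chunk_max - global_max)                  # rescale each chunk to the global max
    num = (chunk_values * scale.unsqueeze(-1)).sum(dim=0)
    den = (chunk_weights * scale).sum(dim=0).unsqueeze(-1)
    return num / den

q = torch.randn(4, 8, 16)
k = torch.randn(4, 64, 16)
v = torch.randn(4, 64, 16)
ref = torch.softmax(torch.einsum("hqd,hkd->hqk", q, k), dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)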
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
    if (not fa_is_installed):
        raise ValueError(
            "_flash_attn requires that FlashAttention be installed"
        )

    batch_dims = q.shape[:-3]
    no_heads, n, c = q.shape[-3:]
    dtype = q.dtype

    q = q.half()
    k = k.half()
    v = v.half()
    kv_mask = kv_mask.half()

    # [*, B, N, H, C]
    q = q.transpose(-2, -3)
    k = k.transpose(-2, -3)
    v = v.transpose(-2, -3)

    # [B_flat, N, H, C]
    q = q.reshape(-1, *q.shape[-3:])
    k = k.reshape(-1, *k.shape[-3:])
    v = v.reshape(-1, *v.shape[-3:])

    # Flattened batch size
    batch_size = q.shape[0]

    # [B_flat * N, H, C]
    q = q.reshape(-1, *q.shape[-2:])

    q_max_s = n
    q_cu_seqlens = torch.arange(
        0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
    )

    # [B_flat, N, 2, H, C]
    kv = torch.stack([k, v], dim=-3)
    kv_shape = kv.shape

    # [B_flat, N, 2 * H * C]
    kv = kv.reshape(*kv.shape[:-3], -1)

    kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
    kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])

    out = flash_attn_unpadded_kvpacked_func(
        q,
        kv_unpad,
        q_cu_seqlens,
        kv_cu_seqlens,
        q_max_s,
        kv_max_s,
        dropout_p=0.,
        softmax_scale=1.,  # q has been scaled already
    )

    # [*, B, N, H, C]
    out = out.reshape(*batch_dims, n, no_heads, c)

    out = out.to(dtype=dtype)

    return out
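`_flash_attn` packs the batch into FlashAttention's varlen format, where `cu_seqlens` holds cumulative sequence offsets into a flattened [total_tokens, H, C] tensor. A small sketch (plain PyTorch, illustrative shapes, no FlashAttention call) of how those offsets index the packed layout:

import torch

batch_size, n, no_heads, c = 3, 5, 4, 8

# Equal-length queries: offsets are multiples of n, with one extra entry at the end
q_cu_seqlens = torch.arange(0, (batch_size + 1) * n, step=n, dtype=torch.int32)
print(q_cu_seqlens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)

# Packed queries: [batch_size * n, no_heads, c]
q_packed = torch.randn(batch_size * n, no_heads, c)

# Recover the i-th sequence from the packed layout using the offsets
i = 1
q_i = q_packed[q_cu_seqlens[i]:q_cu_seqlens[i + 1]]
assert q_i.shape == (n, no_heads, c)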
openfold/model/structure_module.py View file @ 39a6d0e6
...
...
@@ -12,11 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import reduce
 import importlib
 import math
 import sys
 from operator import mul

 import torch
 import torch.nn as nn
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Sequence, Union

 from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
...
...
@@ -27,11 +31,12 @@ from openfold.np.residue_constants import (
 )
 from openfold.utils.geometry.quat_rigid import QuatRigid
 from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
-from openfold.utils.geometry.vector import Vec3Array
+from openfold.utils.geometry.vector import Vec3Array, square_euclidean_distance
 from openfold.utils.feats import (
     frames_and_literature_positions_to_atom14_pos,
     torsion_angles_to_frames,
 )
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     dict_multimap,
...
...
@@ -39,6 +44,8 @@ from openfold.utils.tensor_utils import (
     flatten_final_dims,
 )

+attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")

 class AngleResnetBlock(nn.Module):
     def __init__(self, c_hidden):
...
...
@@ -164,6 +171,7 @@ class PointProjection(nn.Module):
        super().__init__()
        self.return_local_points = return_local_points
        self.no_heads = no_heads
        self.num_points = num_points

        self.linear = Linear(c_hidden, no_heads * 3 * num_points)
...
...
@@ -173,22 +181,30 @@ class PointProjection(nn.Module):
    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array], torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # TODO: Needs to run in high precision during training
        points_local = self.linear(activations)
-       points_local = points_local.reshape(
-           *points_local.shape[:-1],
-           self.no_heads,
-           -1,
-       )
+       if isinstance(rigids, Rigid3Array):
+           points_local = points_local.reshape(
+               *points_local.shape[:-1],
+               self.no_heads,
+               -1,
+           )
        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )
        points_local = torch.stack(points_local, dim=-1)
+       if not isinstance(rigids, Rigid3Array):
+           points_local = points_local.reshape(
+               *points_local.shape[:-2],
+               self.no_heads,
+               -1,
+               3
+           )
        points_global = rigids[..., None, None].apply(points_local)

        if (self.return_local_points):
            return points_global, points_local

        return points_global


 class InvariantPointAttention(nn.Module):
...
...
@@ -242,8 +258,8 @@ class InvariantPointAttention(nn.Module):
        self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))
        self.linear_q_points = PointProjection(
            self.c_s,
            self.no_qk_points,
            self.no_heads
        )
...
...
@@ -288,6 +304,9 @@ class InvariantPointAttention(nn.Module):
        z: torch.Tensor,
        r: Union[Rigid, Rigid3Array],
        mask: torch.Tensor,
+       inplace_safe: bool = False,
+       _offload_inference: bool = False,
+       _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
...
...
@@ -302,6 +321,11 @@ class InvariantPointAttention(nn.Module):
        Returns:
            [*, N_res, C_s] single representation update
        """
+       if (_offload_inference and inplace_safe):
+           z = _z_reference_list
+       else:
+           z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
...
...
@@ -312,7 +336,7 @@ class InvariantPointAttention(nn.Module):
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, P_qk]
        q_pts = self.linear_q_points(s, r)

        # The following two blocks are equivalent
        # They're separated only to preserve compatibility with old AF weights
...
...
@@ -351,13 +375,25 @@ class InvariantPointAttention(nn.Module):
        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
-       b = self.linear_b(z)
+       b = self.linear_b(z[0])

+       if (_offload_inference):
+           assert (sys.getrefcount(z[0]) == 2)
+           z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
-       a = torch.matmul(
-           permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
-           permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
-       )
+       if (is_fp16_enabled()):
+           with torch.cuda.amp.autocast(enabled=False):
+               a = torch.matmul(
+                   permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                   permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
+               )
+       else:
+           a = torch.matmul(
+               permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+               permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
+           )

        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
...
...
@@ -369,7 +405,12 @@ class InvariantPointAttention(nn.Module):
            pt_att = sum([c ** 2 for c in pt_att])
        else:
            pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
-           pt_att = pt_att ** 2
+           if (inplace_safe):
+               pt_att *= pt_att
+           else:
+               pt_att = pt_att ** 2

            pt_att = sum(torch.unbind(pt_att, dim=-1))

        head_weights = self.softplus(self.head_weights).view(
...
...
@@ -378,7 +419,11 @@ class InvariantPointAttention(nn.Module):
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
-       pt_att = pt_att * head_weights
+       if (inplace_safe):
+           pt_att *= head_weights
+       else:
+           pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
...
...
@@ -388,9 +433,21 @@ class InvariantPointAttention(nn.Module):
        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))
-       a = a + pt_att
-       a = a + square_mask.unsqueeze(-3)
-       a = self.softmax(a)
+       if (inplace_safe):
+           a += pt_att
+           del pt_att
+           a += square_mask.unsqueeze(-3)
+           # in-place softmax
+           attn_core_inplace_cuda.forward_(
+               a,
+               reduce(mul, a.shape[:-1]),
+               a.shape[-1],
+           )
+       else:
+           a = a + pt_att
+           a = a + square_mask.unsqueeze(-3)
+           a = self.softmax(a)

        ################
        # Compute output
...
...
@@ -419,13 +476,22 @@ class InvariantPointAttention(nn.Module):
            # [*, N_res, H * P_v]
            o_pt_norm = o_pt.norm(self.eps)
        else:
-           o_pt = torch.sum(
-               (
-                   a[..., None, :, :, None]
-                   * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
-               ),
-               dim=-2,
-           )
+           # [*, H, 3, N_res, P_v]
+           if (inplace_safe):
+               v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
+               o_pt = [
+                   torch.matmul(a, v.to(a.dtype))
+                   for v in torch.unbind(v_pts, dim=-3)
+               ]
+               o_pt = torch.stack(o_pt, dim=-3)
+           else:
+               o_pt = torch.sum(
+                   (
+                       a[..., None, :, :, None]
+                       * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+                   ),
+                   dim=-2,
+               )

            # [*, N_res, H, P_v, 3]
            o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
...
...
@@ -440,8 +506,11 @@ class InvariantPointAttention(nn.Module):
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
        o_pt = torch.unbind(o_pt, dim=-1)

+       if (_offload_inference):
+           z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
-       o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
+       o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)
...
...
@@ -450,7 +519,7 @@ class InvariantPointAttention(nn.Module):
        s = self.linear_out(
            torch.cat(
-               (o, *o_pt, o_pt_norm, o_pair), dim=-1
-           ).to(dtype=z.dtype)
+               (o, *o_pt, o_pt_norm, o_pair), dim=-1
+           ).to(dtype=z[0].dtype)
        )

        return s
...
...
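The `_offload_inference` path above parks the pair embedding on the CPU while the large attention matrices are built, then restores it just before it is consumed again, using a `sys.getrefcount` assertion to confirm that no other reference keeps the GPU copy alive. A minimal sketch of that pattern (illustrative only; `expensive_step` is a made-up placeholder):

import sys
import torch

def offloaded_step(z_list, expensive_step):
    # z_list is a single-element list so the caller's own reference can be dropped
    assert sys.getrefcount(z_list[0]) == 2   # only the list plus getrefcount's argument
    device = z_list[0].device
    z_list[0] = z_list[0].cpu()              # release the on-device copy
    out = expensive_step()                   # peak-memory work happens without z resident
    z_list[0] = z_list[0].to(device)         # bring it back before it is needed again
    return out

z = [torch.randn(64, 64, 16)]
result = offloaded_step(z, lambda: torch.zeros(1))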
@@ -611,11 +680,11 @@ class StructureModule(nn.Module):
        self.inf = inf
        self.is_multimer = is_multimer

-       # To be lazily initialized later
-       self.default_frames = None
-       self.group_idx = None
-       self.atom_mask = None
-       self.lit_positions = None
+       # Buffers to be lazily initialized later
+       # self.default_frames
+       # self.group_idx
+       # self.atom_mask
+       # self.lit_positions

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)
...
...
@@ -655,62 +724,32 @@ class StructureModule(nn.Module):
            self.no_angles,
            self.epsilon,
        )

-   def _init_residue_constants(self, float_dtype, device):
-       if self.default_frames is None:
-           self.default_frames = torch.tensor(
-               restype_rigid_group_default_frame,
-               dtype=float_dtype, device=device, requires_grad=False,
-           )
-       if self.group_idx is None:
-           self.group_idx = torch.tensor(
-               restype_atom14_to_rigid_group,
-               device=device, requires_grad=False,
-           )
-       if self.atom_mask is None:
-           self.atom_mask = torch.tensor(
-               restype_atom14_mask,
-               dtype=float_dtype, device=device, requires_grad=False,
-           )
-       if self.lit_positions is None:
-           self.lit_positions = torch.tensor(
-               restype_atom14_rigid_group_positions,
-               dtype=float_dtype, device=device, requires_grad=False,
-           )
-
-   def torsion_angles_to_frames(self, r, alpha, f):
-       # Lazily initialize the residue constants on the correct device
-       self._init_residue_constants(alpha.dtype, alpha.device)
-       # Separated purely to make testing less annoying
-       return torsion_angles_to_frames(r, alpha, f, self.default_frames)
-
-   def frames_and_literature_positions_to_atom14_pos(
-       self, r, f  # [*, N, 8]  # [*, N]
-   ):
-       # Lazily initialize the residue constants on the correct device
-       self._init_residue_constants(r.dtype, r.device)
-       return frames_and_literature_positions_to_atom14_pos(
-           r, f,
-           self.default_frames, self.group_idx, self.atom_mask, self.lit_positions,
-       )

-   def _forward_monomer(
-       self,
-       s,
-       z,
+   def _forward_monomer(
+       self,
+       evoformer_output_dict,
        aatype,
        mask=None,
+       inplace_safe=False,
+       _offload_inference=False,
    ):
+       """
+       Args:
+           evoformer_output_dict:
+               Dictionary containing:
+                   "single":
+                       [*, N_res, C_s] single representation
+                   "pair":
+                       [*, N_res, N_res, C_z] pair representation
+           aatype:
+               [*, N_res] amino acid indices
+           mask:
+               Optional [*, N_res] sequence mask
+       Returns:
+           A dictionary of outputs
+       """
+       s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])
...
...
@@ -719,7 +758,14 @@ class StructureModule(nn.Module):
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
-       z = self.layer_norm_z(z)
+       z = self.layer_norm_z(evoformer_output_dict["pair"])

+       z_reference_list = None
+       if (_offload_inference):
+           assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+           evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+           z_reference_list = [z]
+           z = None

        # [*, N, C_s]
        s_initial = s
...
...
@@ -736,11 +782,19 @@ class StructureModule(nn.Module):
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
-           s = s + self.ipa(s, z, rigids, mask)
+           s = s + self.ipa(
+               s, z, rigids, mask,
+               inplace_safe=inplace_safe,
+               _offload_inference=_offload_inference,
+               _z_reference_list=z_reference_list
+           )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))
...
...
@@ -781,24 +835,35 @@ class StructureModule(nn.Module):
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

-           if i < (self.no_blocks - 1):
-               rigids = rigids.stop_rot_gradient()
+           rigids = rigids.stop_rot_gradient()

+       del z, z_reference_list

+       if (_offload_inference):
+           evoformer_output_dict["pair"] = (
+               evoformer_output_dict["pair"].to(s.device)
+           )

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

-   def _forward_multimer(
-       self,
-       s,
-       z,
-       aatype,
-       mask=None,
-   ):
+   def _forward_multimer(
+       self,
+       evoformer_output_dict,
+       aatype,
+       mask=None,
+       inplace_safe=False,
+       _offload_inference=False,
+   ):
+       s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])
...
...
@@ -807,7 +872,14 @@ class StructureModule(nn.Module):
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
-       z = self.layer_norm_z(z)
+       z = self.layer_norm_z(evoformer_output_dict["pair"])

+       z_reference_list = None
+       if (_offload_inference):
+           assert (sys.getrefcount(evoformer_output_dict["pair"]) == 2)
+           evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+           z_reference_list = [z]
+           z = None

        # [*, N, C_s]
        s_initial = s
...
...
@@ -821,7 +893,15 @@ class StructureModule(nn.Module):
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
-           s = s + self.ipa(s, z, rigids, mask)
+           s = s + self.ipa(
+               s, z, rigids, mask,
+               inplace_safe=inplace_safe,
+               _offload_inference=_offload_inference,
+               _z_reference_list=z_reference_list
+           )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)
...
...
@@ -848,13 +928,19 @@ class StructureModule(nn.Module):
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
-               "positions": pred_xyz.to_tensor(),
+               "positions": pred_xyz,
            }

            outputs.append(preds)

-           if i < (self.no_blocks - 1):
-               rigids = rigids.stop_rot_gradient()
+           rigids = rigids.stop_rot_gradient()

+       del z, z_reference_list

+       if (_offload_inference):
+           evoformer_output_dict["pair"] = (
+               evoformer_output_dict["pair"].to(s.device)
+           )

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s
...
...
@@ -863,10 +949,11 @@ class StructureModule(nn.Module):
    def forward(
        self,
-       s,
-       z,
+       evoformer_output_dict,
        aatype,
        mask=None,
+       inplace_safe=False,
+       _offload_inference=False,
    ):
        """
        Args:
...
...
@@ -882,8 +969,73 @@ class StructureModule(nn.Module):
            A dictionary of outputs
        """
        if (self.is_multimer):
-           outputs = self._forward_multimer(s, z, aatype, mask)
+           outputs = self._forward_multimer(
+               evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference
+           )
        else:
-           outputs = self._forward_monomer(s, z, aatype, mask)
+           outputs = self._forward_monomer(
+               evoformer_output_dict, aatype, mask, inplace_safe, _offload_inference
+           )

        return outputs

+   def _init_residue_constants(self, float_dtype, device):
+       if not hasattr(self, "default_frames"):
+           self.register_buffer(
+               "default_frames",
+               torch.tensor(
+                   restype_rigid_group_default_frame,
+                   dtype=float_dtype, device=device, requires_grad=False,
+               ),
+               persistent=False,
+           )
+       if not hasattr(self, "group_idx"):
+           self.register_buffer(
+               "group_idx",
+               torch.tensor(
+                   restype_atom14_to_rigid_group,
+                   device=device, requires_grad=False,
+               ),
+               persistent=False,
+           )
+       if not hasattr(self, "atom_mask"):
+           self.register_buffer(
+               "atom_mask",
+               torch.tensor(
+                   restype_atom14_mask,
+                   dtype=float_dtype, device=device, requires_grad=False,
+               ),
+               persistent=False,
+           )
+       if not hasattr(self, "lit_positions"):
+           self.register_buffer(
+               "lit_positions",
+               torch.tensor(
+                   restype_atom14_rigid_group_positions,
+                   dtype=float_dtype, device=device, requires_grad=False,
+               ),
+               persistent=False,
+           )

+   def torsion_angles_to_frames(self, r, alpha, f):
+       # Lazily initialize the residue constants on the correct device
+       self._init_residue_constants(alpha.dtype, alpha.device)
+       # Separated purely to make testing less annoying
+       return torsion_angles_to_frames(r, alpha, f, self.default_frames)

+   def frames_and_literature_positions_to_atom14_pos(
+       self, r, f  # [*, N, 8]  # [*, N]
+   ):
+       # Lazily initialize the residue constants on the correct device
+       self._init_residue_constants(r.dtype, r.device)
+       return frames_and_literature_positions_to_atom14_pos(
+           r,
+           f,
+           self.default_frames,
+           self.group_idx,
+           self.atom_mask,
+           self.lit_positions,
+       )
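The merged `_init_residue_constants` switches the lazily created residue-constant tensors from plain attributes to non-persistent buffers, so they follow `.to()`/`.half()` with the module but are not written into checkpoints. A small sketch of the difference (hypothetical `MyModule`, illustrative constant):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def _init_constants(self, dtype, device):
        if not hasattr(self, "lit_positions"):
            # Registered lazily, on first use, with the dtype/device of the inputs
            self.register_buffer(
                "lit_positions",
                torch.eye(3, dtype=dtype, device=device),
                persistent=False,  # excluded from state_dict()
            )

    def forward(self, x):
        self._init_constants(x.dtype, x.device)
        return x @ self.lit_positions

m = MyModule()
m(torch.randn(2, 3))
assert "lit_positions" not in m.state_dict()   # persistent=False keeps checkpoints unchanged
m.half()                                        # buffers still track module-wide casts
assert m.lit_positions.dtype == torch.float16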
openfold/model/template.py View file @ 39a6d0e6
...
...
@@ -14,6 +14,7 @@
 # limitations under the License.
 from functools import partial
 import math
+import sys
 from typing import Optional, List

 import torch
...
...
@@ -34,10 +35,19 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationIncoming,
 )
 from openfold.utils.checkpointing import checkpoint_blocks
-from openfold.utils.tensor_utils import (
+from openfold.utils.chunk_utils import (
     chunk_layer,
+    ChunkSizeTuner,
+)
+from openfold.utils.feats import (
+    build_template_angle_feat,
+    build_template_pair_feat,
+)
+from openfold.utils.tensor_utils import (
+    add,
     permute_final_dims,
     flatten_final_dims,
     tensor_tree_map,
 )
...
...
@@ -77,6 +87,7 @@ class TemplatePointwiseAttention(nn.Module):
        t: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
+       use_lma: bool = False,
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": z,
...
...
@@ -84,7 +95,7 @@ class TemplatePointwiseAttention(nn.Module):
            "biases": biases,
        }
        return chunk_layer(
-           self.mha,
+           partial(self.mha, use_lma=use_lma),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
...
...
@@ -95,7 +106,9 @@ class TemplatePointwiseAttention(nn.Module):
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
-       chunk_size: Optional[int] = None
+       # This module suffers greatly from a small chunk size
+       chunk_size: Optional[int] = 256,
+       use_lma: bool = False,
    ) -> torch.Tensor:
        """
        Args:
...
...
@@ -121,10 +134,10 @@ class TemplatePointwiseAttention(nn.Module):
        # [*, N_res, N_res, 1, C_z]
        biases = [bias]
-       if chunk_size is not None:
-           z = self._chunk(z, t, biases, chunk_size)
+       if chunk_size is not None and not self.training:
+           z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
        else:
-           z = self.mha(q_x=z, kv_x=t, biases=biases)
+           z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)
...
...
@@ -186,74 +199,118 @@ class TemplatePairStackBlock(nn.Module):
            self.pair_transition_n,
        )

-   def tri_att_start_end(self, single, single_mask, chunk_size):
-       single = single + self.dropout_row(
-           self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)
-       )
-       single = single + self.dropout_col(
-           self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)
-       )
+   def tri_att_start_end(
+       self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe
+   ):
+       single = add(
+           single,
+           self.dropout_row(
+               self.tri_att_start(
+                   single,
+                   chunk_size=_attn_chunk_size,
+                   mask=single_mask,
+                   use_lma=use_lma,
+                   inplace_safe=inplace_safe,
+               )
+           ),
+           inplace_safe,
+       )
+       single = add(
+           single,
+           self.dropout_col(
+               self.tri_att_end(
+                   single,
+                   chunk_size=_attn_chunk_size,
+                   mask=single_mask,
+                   use_lma=use_lma,
+                   inplace_safe=inplace_safe,
+               )
+           ),
+           inplace_safe,
+       )
        return single

-   def tri_mul_out_in(self, single, single_mask):
-       single = single + self.dropout_row(
-           self.tri_mul_out(single, mask=single_mask)
-       )
-       single = single + self.dropout_row(
-           self.tri_mul_in(single, mask=single_mask)
-       )
+   def tri_mul_out_in(self, single, single_mask, inplace_safe):
+       tmu_update = self.tri_mul_out(
+           single,
+           mask=single_mask,
+           inplace_safe=inplace_safe,
+           _add_with_inplace=True,
+       )
+       if (not inplace_safe):
+           single = single + self.dropout_row(tmu_update)
+       else:
+           single = tmu_update
+       del tmu_update
+
+       tmu_update = self.tri_mul_in(
+           single,
+           mask=single_mask,
+           inplace_safe=inplace_safe,
+           _add_with_inplace=True,
+       )
+       if (not inplace_safe):
+           single = single + self.dropout_row(tmu_update)
+       else:
+           single = tmu_update
+       del tmu_update
        return single

    def forward(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
-       _mask_trans: bool = True
+       use_lma: bool = False,
+       inplace_safe: bool = False,
+       _mask_trans: bool = True,
+       _attn_chunk_size: Optional[int] = None,
    ):
+       if (_attn_chunk_size is None):
+           _attn_chunk_size = chunk_size

        single_templates = [
            t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
        ]
        single_templates_masks = [
            m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
        ]

        for i in range(len(single_templates)):
            single = single_templates[i]
            single_mask = single_templates_masks[i]

            if self.tri_mul_first:
-               single = self.tri_att_start_end(
-                   single=self.tri_mul_out_in(
-                       single=single, single_mask=single_mask
-                   ),
-                   single_mask=single_mask,
-                   chunk_size=chunk_size
-               )
+               single = self.tri_att_start_end(
+                   single=self.tri_mul_out_in(
+                       single=single,
+                       single_mask=single_mask,
+                       inplace_safe=inplace_safe,
+                   ),
+                   _attn_chunk_size=_attn_chunk_size,
+                   single_mask=single_mask,
+                   use_lma=use_lma,
+                   inplace_safe=inplace_safe,
+               )
            else:
-               single = self.tri_mul_out_in(
-                   single=self.tri_att_start_end(
-                       single=single,
-                       single_mask=single_mask,
-                       chunk_size=chunk_size
-                   ),
-                   single_mask=single_mask
-               )
+               single = self.tri_mul_out_in(
+                   single=self.tri_att_start_end(
+                       single=single,
+                       _attn_chunk_size=_attn_chunk_size,
+                       single_mask=single_mask,
+                       use_lma=use_lma,
+                       inplace_safe=inplace_safe,
+                   ),
+                   single_mask=single_mask,
+                   inplace_safe=inplace_safe,
+               )

-           single = single + self.pair_transition(
-               single,
-               mask=single_mask if _mask_trans else None,
-               chunk_size=chunk_size,
-           )
+           single = add(
+               single,
+               self.pair_transition(
+                   single,
+                   mask=single_mask if _mask_trans else None,
+                   chunk_size=chunk_size,
+               ),
+               inplace_safe,
+           )

-           single_templates[i] = single
+           if (not inplace_safe):
+               single_templates[i] = single

-       z = torch.cat(single_templates, dim=-4)
+       if (not inplace_safe):
+           z = torch.cat(single_templates, dim=-4)

        return z
...
...
@@ -273,6 +330,7 @@ class TemplatePairStack(nn.Module):
        dropout_rate,
        tri_mul_first,
        blocks_per_ckpt,
+       tune_chunk_size: bool = False,
        inf=1e9,
        **kwargs,
    ):
...
...
@@ -314,11 +372,18 @@ class TemplatePairStack(nn.Module):
        self.layer_norm = LayerNorm(c_t)

+       self.tune_chunk_size = tune_chunk_size
+       self.chunk_size_tuner = None
+       if (tune_chunk_size):
+           self.chunk_size_tuner = ChunkSizeTuner()

    def forward(
        self,
        t: torch.tensor,
        mask: torch.tensor,
        chunk_size: int,
+       use_lma: bool = False,
+       inplace_safe: bool = False,
        _mask_trans: bool = True,
    ):
        """
...
...
@@ -335,16 +400,34 @@ class TemplatePairStack(nn.Module):
            expand_idx[-3] = t.shape[-4]
            mask = mask.expand(*expand_idx)

+       blocks = [
+           partial(
+               b,
+               mask=mask,
+               chunk_size=chunk_size,
+               use_lma=use_lma,
+               inplace_safe=inplace_safe,
+               _mask_trans=_mask_trans,
+           )
+           for b in self.blocks
+       ]

+       if (chunk_size is not None and self.chunk_size_tuner is not None):
+           assert (not self.training)
+           tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
+               representative_fn=blocks[0],
+               args=(t.clone(),),
+               min_chunk_size=chunk_size,
+           )
+           blocks = [
+               partial(
+                   b,
+                   chunk_size=tuned_chunk_size,
+                   _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
+               )
+               for b in blocks
+           ]

        t, = checkpoint_blocks(
-           blocks=[
-               partial(
-                   b,
-                   mask=mask,
-                   chunk_size=chunk_size,
-                   _mask_trans=_mask_trans,
-               )
-               for b in self.blocks
-           ],
+           blocks=blocks,
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )
...
...
@@ -352,3 +435,223 @@ class TemplatePairStack(nn.Module):
        t = self.layer_norm(t)
        return t
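The `tune_chunk_size` path above calls the tuner once on a representative block to pick a larger chunk size for inference. The sketch below is not OpenFold's `ChunkSizeTuner`; it is only the general idea under the assumption that a larger chunk is acceptable whenever the representative call survives (e.g. does not run out of memory):

import torch

def naive_tune_chunk_size(representative_fn, args, min_chunk_size: int, max_chunk_size: int = 512):
    """Return the largest power-of-two chunk size the representative call survives."""
    best = min_chunk_size
    c = min_chunk_size
    while c <= max_chunk_size:
        try:
            representative_fn(*args, chunk_size=c)
            best = c
        except RuntimeError:   # typically a CUDA out-of-memory error
            break
        c *= 2
    return best

# Hypothetical block: chunk_size only changes how work is split, not the result
def block(x, chunk_size=None):
    return torch.relu(x)

print(naive_tune_chunk_size(block, (torch.randn(8, 8),), min_chunk_size=4))  # 512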
def embed_templates_offload(
    model,
    batch,
    z,
    pair_mask,
    templ_dim,
    template_chunk_size=256,
    inplace_safe=False,
):
    """
    Args:
        model:
            An AlphaFold model object
        batch:
            An AlphaFold input batch. See documentation of AlphaFold.
        z:
            A [*, N, N, C_z] pair embedding
        pair_mask:
            A [*, N, N] pair mask
        templ_dim:
            The template dimension of the template tensors in batch
        template_chunk_size:
            Integer value controlling how quickly the offloaded pair embedding
            tensor is brought back into GPU memory. In dire straits, can be
            lowered to reduce memory consumption of this function even more.
    Returns:
        A dictionary of template pair and angle embeddings.

    A version of the "embed_templates" method of the AlphaFold class that
    offloads the large template pair tensor to CPU. Slower but more frugal
    with GPU memory than the original. Useful for long-sequence inference.
    """
    # Embed the templates one at a time (with a poor man's vmap)
    pair_embeds_cpu = []
    n = z.shape[-2]
    n_templ = batch["template_aatype"].shape[templ_dim]
    for i in range(n_templ):
        idx = batch["template_aatype"].new_tensor(i)
        single_template_feats = tensor_tree_map(
            lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
            batch,
        )

        # [*, N, N, C_t]
        t = build_template_pair_feat(
            single_template_feats,
            use_unit_vector=model.config.template.use_unit_vector,
            inf=model.config.template.inf,
            eps=model.config.template.eps,
            **model.config.template.distogram,
        ).to(z.dtype)
        t = model.template_pair_embedder(t)

        # [*, 1, N, N, C_z]
        t = model.template_pair_stack(
            t.unsqueeze(templ_dim),
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=model.globals.chunk_size,
            use_lma=model.globals.use_lma,
            _mask_trans=model.config._mask_trans,
        )

        assert (sys.getrefcount(t) == 2)
        pair_embeds_cpu.append(t.cpu())
        del t

    # Preallocate the output tensor
    t = z.new_zeros(z.shape)

    for i in range(0, n, template_chunk_size):
        pair_chunks = [
            p[..., i: i + template_chunk_size, :, :] for p in pair_embeds_cpu
        ]
        pair_chunk = torch.cat(pair_chunks, dim=templ_dim).to(device=z.device)
        z_chunk = z[..., i: i + template_chunk_size, :, :]
        att_chunk = model.template_pointwise_att(
            pair_chunk,
            z_chunk,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            use_lma=model.globals.use_lma,
        )

        t[..., i: i + template_chunk_size, :, :] = att_chunk
        del pair_chunks

    if (inplace_safe):
        t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
    else:
        t *= (torch.sum(batch["template_mask"], dim=-1) > 0)

    ret = {}
    if model.config.template.embed_angles:
        template_angle_feat = build_template_angle_feat(
            batch,
        )

        # [*, N, C_m]
        a = model.template_angle_embedder(template_angle_feat)

        ret["template_single_embedding"] = a

    ret.update({"template_pair_embedding": t})

    return ret
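The function above embeds templates one at a time, parks each result on the CPU, and then streams row-chunks of the stacked embeddings back onto the GPU for the pointwise attention. A stripped-down sketch of that append-to-CPU / reload-in-chunks pattern (illustrative shapes; the mean is a stand-in for the real attention step):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
n_templ, n, c = 4, 32, 8

# Stage 1: produce one big tensor per template, keep each on the CPU
per_template_cpu = []
for _ in range(n_templ):
    t = torch.randn(1, n, n, c, device=device)   # stand-in for one embedded template
    per_template_cpu.append(t.cpu())
    del t                                         # drop the on-device copy immediately

# Stage 2: reassemble only a slab of rows at a time on the device
out = torch.zeros(n, n, c, device=device)
chunk = 16
for i in range(0, n, chunk):
    slab = torch.cat(
        [p[..., i:i + chunk, :, :] for p in per_template_cpu], dim=0
    ).to(device)                                  # [n_templ, chunk, n, c]
    out[i:i + chunk] = slab.mean(dim=0)           # stand-in for pointwise attention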
def embed_templates_average(
    model,
    batch,
    z,
    pair_mask,
    templ_dim,
    templ_group_size=2,
    inplace_safe=False,
):
    """
    Args:
        model:
            An AlphaFold model object
        batch:
            An AlphaFold input batch. See documentation of AlphaFold.
        z:
            A [*, N, N, C_z] pair embedding
        pair_mask:
            A [*, N, N] pair mask
        templ_dim:
            The template dimension of the template tensors in batch
        templ_group_size:
            Granularity of the approximation. Larger values trade memory for
            greater proximity to the original function
    Returns:
        A dictionary of template pair and angle embeddings.

    A memory-efficient approximation of the "embed_templates" method of the
    AlphaFold class. Instead of running pointwise attention over pair
    embeddings for all of the templates at the same time, it splits templates
    into groups of size templ_group_size, computes embeddings for each group
    normally, and then averages the group embeddings. In our experiments, this
    approximation has a minimal effect on the quality of the resulting
    embedding, while its low memory footprint allows the number of templates
    to scale almost indefinitely.
    """
    # Embed the templates one at a time (with a poor man's vmap)
    n = z.shape[-2]
    n_templ = batch["template_aatype"].shape[templ_dim]
    out_tensor = z.new_zeros(z.shape)
    for i in range(0, n_templ, templ_group_size):
        def slice_template_tensor(t):
            s = [slice(None) for _ in t.shape]
            s[templ_dim] = slice(i, i + templ_group_size)
            return t[s]

        template_feats = tensor_tree_map(
            slice_template_tensor,
            batch,
        )

        # [*, N, N, C_t]
        t = build_template_pair_feat(
            template_feats,
            use_unit_vector=model.config.template.use_unit_vector,
            inf=model.config.template.inf,
            eps=model.config.template.eps,
            **model.config.template.distogram,
        ).to(z.dtype)

        # [*, S_t, N, N, C_z]
        t = model.template_pair_embedder(t)
        t = model.template_pair_stack(
            t,
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=model.globals.chunk_size,
            use_lma=model.globals.use_lma,
            _mask_trans=model.config._mask_trans,
        )

        t = model.template_pointwise_att(
            t,
            z,
            template_mask=template_feats["template_mask"].to(dtype=z.dtype),
            use_lma=model.globals.use_lma,
        )

        denom = math.ceil(n_templ / templ_group_size)
        if (inplace_safe):
            t /= denom
        else:
            t = t / denom

        if (inplace_safe):
            out_tensor += t
        else:
            out_tensor = out_tensor + t

        del t

    if (inplace_safe):
        out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
    else:
        out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)

    ret = {}
    if model.config.template.embed_angles:
        template_angle_feat = build_template_angle_feat(
            batch,
        )

        # [*, N, C_m]
        a = model.template_angle_embedder(template_angle_feat)

        ret["template_single_embedding"] = a

    ret.update({"template_pair_embedding": out_tensor})

    return ret
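The docstring above describes group-wise averaging as an approximation to attending over all templates at once. A tiny numeric sketch of why dividing each group result by `denom = ceil(n_templ / group_size)` and summing reproduces an exact mean when the per-group result is itself a simple average (the real group result is a full attention pass, so this is only the arithmetic skeleton):

import math
import torch

def embed_one(x):           # stand-in for embedding one group of templates
    return x.mean(dim=0)    # here: just average the group members

templates = torch.randn(6, 8)      # 6 "templates", feature size 8
group_size = 2
denom = math.ceil(templates.shape[0] / group_size)

out = torch.zeros(8)
for i in range(0, templates.shape[0], group_size):
    out += embed_one(templates[i:i + group_size]) / denom

# With equal-sized groups, averaging the group results equals the full average
assert torch.allclose(out, templates.mean(dim=0), atol=1e-6)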
openfold/model/triangular_attention.py View file @ 39a6d0e6
...
...
@@ -21,8 +21,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.utils.chunk_utils import chunk_layer
 from openfold.utils.tensor_utils import (
-    chunk_layer,
     permute_final_dims,
     flatten_final_dims,
 )
...
...
@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
 class TriangleAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, starting, inf=1e9
+        self, c_in, c_hidden, no_heads, starting=True, inf=1e9
     ):
         """
         Args:
...
...
@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
+       use_memory_efficient_kernel: bool = False,
+       use_lma: bool = False,
+       inplace_safe: bool = False,
    ) -> torch.Tensor:
        "triangle! triangle!"
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
-           partial(self.mha),
+           partial(
+               self.mha,
+               use_memory_efficient_kernel=use_memory_efficient_kernel,
+               use_lma=use_lma,
+           ),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
+           _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
-       chunk_size: Optional[int] = None
+       chunk_size: Optional[int] = None,
+       use_memory_efficient_kernel: bool = False,
+       use_lma: bool = False,
+       inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
...
...
@@ -86,15 +99,14 @@ class TriangleAttention(nn.Module):
            [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        # Shape annotations assume self.starting. Else, I and J are flipped
-       if not self.starting:
+       if (not self.starting):
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)
...
...
@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
        biases = [mask_bias, triangle_bias]
        if chunk_size is not None:
-           x = self._chunk(x, biases, chunk_size)
+           x = self._chunk(
+               x,
+               biases,
+               chunk_size,
+               use_memory_efficient_kernel=use_memory_efficient_kernel,
+               use_lma=use_lma,
+               inplace_safe=inplace_safe,
+           )
        else:
-           x = self.mha(q_x=x, kv_x=x, biases=biases)
+           x = self.mha(
+               q_x=x,
+               kv_x=x,
+               biases=biases,
+               use_memory_efficient_kernel=use_memory_efficient_kernel,
+               use_lma=use_lma,
+           )

-       if not self.starting:
+       if (not self.starting):
            x = x.transpose(-2, -3)

        return x


-class TriangleAttentionStartingNode(TriangleAttention):
-    """
-    Implements Algorithm 13.
-    """
-    __init__ = partialmethod(TriangleAttention.__init__, starting=True)
+# Implements Algorithm 13
+TriangleAttentionStartingNode = TriangleAttention


 class TriangleAttentionEndingNode(TriangleAttention):
     """
     Implements Algorithm 14.
     """
     __init__ = partialmethod(TriangleAttention.__init__, starting=False)
openfold/model/triangular_multiplicative_update.py View file @ 39a6d0e6
...
...
@@ -20,7 +20,9 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import permute_final_dims
+from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled
+from openfold.utils.tensor_utils import add, permute_final_dims


 class TriangleMultiplicativeUpdate(nn.Module):
...
...
@@ -55,12 +57,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
    def _combine_projections(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
+       _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
-       raise NotImplementedError("This method needs to be overridden")
+       if (self._outgoing):
+           a = permute_final_dims(a, (2, 0, 1))
+           b = permute_final_dims(b, (2, 1, 0))
+       else:
+           a = permute_final_dims(a, (2, 1, 0))
+           b = permute_final_dims(b, (2, 0, 1))
+
+       if (_inplace_chunk_size is not None):
+           # To be replaced by torch vmap
+           for i in range(0, a.shape[-3], _inplace_chunk_size):
+               a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
+               b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
+               a[..., i: i + _inplace_chunk_size, :, :] = (
+                   torch.matmul(
+                       a_chunk,
+                       b_chunk,
+                   )
+               )
+
+           p = a
+       else:
+           p = torch.matmul(a, b)
+
+       return permute_final_dims(p, (1, 2, 0))

+   def _inference_forward(
+       self,
+       z: torch.Tensor,
+       mask: Optional[torch.Tensor] = None,
+       inplace_chunk_size: Optional[int] = None,
+       with_add: bool = True,
+   ):
+       """
+       Args:
+           z:
+               A [*, N, N, C_z] pair representation
+           mask:
+               A [*, N, N] pair mask
+           inplace_chunk_size:
+               Size of chunks used in the main computation. Increase to trade
+               memory for speed.
+           with_add:
+               If True, z is overwritten with (z + update). Otherwise, it is
+               overwritten with (update).
+       Returns:
+           A reference to the overwritten z
+
+       More memory-efficient, inference-only version of the forward function.
+       Uses in-place operations, fusion of the addition that happens after
+       this module in the Evoformer, a smidge of recomputation, and
+       a cache of overwritten values to lower peak memory consumption of this
+       module from 5x the size of the input tensor z to 2.5x its size. Useful
+       for inference on extremely long sequences.
+
+       It works as follows. We will make reference to variables used in the
+       default forward implementation below. Naively, triangle multiplication
+       attention requires the manifestation of 5 tensors the size of z:
+       1) z, the "square" input tensor, 2) a, the first projection of z,
+       3) b, the second projection of b, 4) g, a z-sized mask, and 5) a
+       z-sized tensor for intermediate computations. For large N, this is
+       prohibitively expensive; for N=4000, for example, z is more than 8GB
+       alone. To avoid this problem, we compute b, g, and all intermediate
+       tensors in small chunks, noting that the chunks required to compute a
+       chunk of the output depend only on the tensor a and corresponding
+       vertical and horizontal chunks of z. This suggests an algorithm that
+       loops over pairs of chunks of z: hereafter "columns" and "rows" of
+       z, even though each "column" and "row" in fact contains
+       inplace_chunk_size contiguous true columns and rows of z. Writing
+       output chunks to a new tensor would bring total memory consumption
+       down to 3x the size of z. However, more memory can be saved by writing
+       output chunks directly to z in-place. WLOG, we choose to write output
+       chunks vertically, overwriting the ith "column" of z at the end of
+       the ith iteration of the main loop. Despite this overwriting, the
+       ith column is always one column ahead of previously overwritten columns
+       and can be recovered directly from z. After the first iteration,
+       however, the ith row of z is always at least partially overwritten. For
+       this reason, we introduce the z-cache, a tensor one-half the size of
+       z. The z-cache initially contains the left half (2nd and 3rd quadrants)
+       of z. For 0 < i < N/2, the missing left part of the ith row of z is
+       recovered from this cache at the beginning of the ith iteration. Once i
+       exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
+       quadrants of z instead. Though the 3rd quadrant of the original z is
+       entirely overwritten at this point, it can be recovered from the z-cache
+       itself. Thereafter, the ith row of z can be recovered in its entirety
+       from the reoriented z-cache. After the final iteration, z has been
+       completely overwritten and contains the triangular multiplicative
+       update. If with_add is True, it instead contains the sum of z and the
+       triangular multiplicative update. In either case, peak memory
+       consumption is just 2.5x the size of z, disregarding memory used for
+       chunks and other small variables.
+       """
+       if mask is None:
+           mask = z.new_ones(z.shape[:-1])
+
+       mask = mask.unsqueeze(-1)
+
+       def compute_projection_helper(pair, mask, a=True):
+           if (a):
+               linear_g = self.linear_a_g
+               linear_p = self.linear_a_p
+           else:
+               linear_g = self.linear_b_g
+               linear_p = self.linear_b_p
+
+           pair = self.layer_norm_in(pair)
+           p = linear_g(pair)
+           p.sigmoid_()
+           p *= linear_p(pair)
+           p *= mask
+           p = permute_final_dims(p, (2, 0, 1))
+           return p
+
+       def compute_projection(pair, mask, a=True, chunked=True):
+           need_transpose = self._outgoing ^ a
+           if (not chunked):
+               p = compute_projection_helper(pair, mask, a)
+               if (need_transpose):
+                   p = p.transpose(-1, -2)
+           else:
+               # This computation is chunked so as not to exceed our 2.5x
+               # budget with a large intermediate tensor
+               linear_g = self.linear_a_g if a else self.linear_b_g
+               c = linear_g.bias.shape[-1]
+               out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
+               p = pair.new_zeros(out_shape)
+               for i in range(0, pair.shape[-3], inplace_chunk_size):
+                   pair_chunk = pair[..., i: i + inplace_chunk_size, :, :]
+                   mask_chunk = mask[..., i: i + inplace_chunk_size, :, :]
+                   pair_chunk = compute_projection_helper(
+                       pair[..., i: i + inplace_chunk_size, :, :],
+                       mask[..., i: i + inplace_chunk_size, :, :],
+                       a,
+                   )
+                   if (need_transpose):
+                       pair_chunk = pair_chunk.transpose(-1, -2)
+                       p[..., i: i + inplace_chunk_size] = pair_chunk
+                   else:
+                       p[..., i: i + inplace_chunk_size, :] = pair_chunk
+
+                   del pair_chunk
+
+           return p
+
+       # We start by fully manifesting a. In addition to the input, this
+       # brings total memory consumption to 2x z (disregarding size of chunks)
+       # [*, N, N, c]
+       a = compute_projection(z, mask, True, chunked=True)
+
+       if (inplace_chunk_size is not None):
+           n = a.shape[-1]
+           half_n = n // 2 + n % 2
+           row_dim = -3
+           col_dim = -2
+           b_chunk_dim = row_dim if self._outgoing else col_dim
+
+           def empty_slicer(t):
+               return [slice(None) for _ in t.shape]
+
+           def slice_tensor(t, start, end, dim):
+               # Slices start:end from the dim dimension of t
+               s = empty_slicer(t)
+               s[dim] = slice(start, end)
+               return t[s]
+
+           def flip_z_cache_(z_cache, z):
+               # "Reorient" the z_cache (see below), filling it with quadrants
+               # 3---recovered from the z_cache---and 4---recovered from z---
+               # of the input tensor z.
+               quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
+               z_cache = z_cache.transpose(row_dim, col_dim)
+
+               # If n is odd, we need to shrink the z_cache by one row
+               z_cache = z_cache[..., :(n // 2), :, :]
+
+               # Move the 3rd quadrant of z into the
+               first_half_slicer = empty_slicer(z_cache)
+               first_half_slicer[col_dim] = slice(0, half_n)
+               z_cache[first_half_slicer] = quadrant_3
+
+               # Get the fourth quadrant of z
+               quadrant_4 = slice_tensor(z, half_n, None, row_dim)
+               quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
+
+               # Insert said quadrant into the rotated z-cache
+               quadrant_3_slicer = empty_slicer(z_cache)
+               quadrant_3_slicer[col_dim] = slice(half_n, None)
+
+               z_cache[quadrant_3_slicer] = quadrant_4
+
+               return z_cache
+
+           # Initialize the z cache to the left half of z.
+           z_cache_shape = list(z.shape)
+           z_cache_shape[col_dim] = half_n
+           z_cache = z.new_zeros(z_cache_shape)
+           z_cache_slicer = empty_slicer(z_cache)
+           z_cache_slicer[col_dim] = slice(0, half_n)
+           z_cache.copy_(z[z_cache_slicer])
+           z_cache_rotated = False
+
+           # We need to reorient the z-cache at the halfway point, and we
+           # don't want a single chunk to straddle that point. We contract one
+           # of the chunks in the middle to address that problem.
+           i_range = list(range(0, half_n, inplace_chunk_size))
+           initial_offsets = [
+               i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])
+           ]
+           after_half = list(range(half_n, n, inplace_chunk_size))
+           after_half_offsets = [inplace_chunk_size for _ in after_half]
+           combined_range_with_offsets = zip(
+               i_range + after_half, initial_offsets + after_half_offsets
+           )
+           for i, offset in combined_range_with_offsets:
+               if (not z_cache_rotated and i >= half_n):
+                   z_cache = flip_z_cache_(z_cache, z)
+                   z_cache_rotated = True
+
+               z_chunk_b = slice_tensor(
+                   z, i, i + offset, b_chunk_dim,
+               )
+               mask_chunk = slice_tensor(
+                   mask, i, i + offset, b_chunk_dim,
+               )
+
+               z_chunk_b = z_chunk_b.clone()
+               if (b_chunk_dim == col_dim):
+                   z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
+               else:  # b_chunk_dim == row_dim
+                   # In this case, the b-dimension (b_chunk_dim) is partially
+                   # overwritten at the end of each iteration. We need to
+                   # restore the missing component from the z-cache.
+                   if (not z_cache_rotated):
+                       z_chunk_slicer = empty_slicer(z_chunk_b)
+                       z_chunk_slicer[col_dim] = slice(0, half_n)
+                       z_chunk_b[z_chunk_slicer] = slice_tensor(
+                           z_cache, i, i + offset, row_dim,
+                       )
+                   else:
+                       z_cache_offset = i - half_n
+                       z_chunk_b = slice_tensor(
+                           z_cache, z_cache_offset, z_cache_offset + offset, row_dim
+                       )
+
+               b_chunk = compute_projection(
+                   z_chunk_b, mask_chunk, a=False, chunked=False
+               )
+               del z_chunk_b
+
+               x_chunk = torch.matmul(
+                   a,
+                   b_chunk,
+               )
+               x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
+               x_chunk = self.layer_norm_out(x_chunk)
+               x_chunk = self.linear_z(x_chunk)
+
+               # The g dimension (col_dim) is parallel to and ahead of the
+               # overwrites in z. We can extract the g chunk normally.
+               z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
+               g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
+               g_chunk.sigmoid_()
+               del z_chunk_g
+
+               x_chunk *= g_chunk
+
+               # Write the columns into z in-place
+               z_slicer = empty_slicer(z)
+               z_slicer[col_dim] = slice(i, i + offset)
+               if (with_add):
+                   z[z_slicer] += x_chunk
+               else:
+                   z[z_slicer] = x_chunk
+       else:
+           b = compute_projection(z, mask, False, False)
+           x = torch.matmul(a, b)
+           x = self.layer_norm_out(x)
+           x = self.linear_z(x)
+           g = self.linear_g(z)
+           g.sigmoid_()
+           x *= g
+           if (with_add):
+               z += x
+           else:
+               z = x
+
+       return z
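The long docstring above hinges on one property: each iteration writes one column block of the output back into `z`, and the block it needs to read has not yet been overwritten. A toy sketch of that write-in-place column loop, with none of the cache machinery (the real function also needs the half-sized z-cache because its row reads touch already-overwritten columns):

import torch

def inplace_column_update_(z, weight):
    """Overwrite z with a column-blocked update without allocating a z-sized output."""
    n = z.shape[-1]
    chunk = 2
    a = z.clone()                       # stand-in for the precomputed projection "a"
    for i in range(0, n, chunk):
        col = z[:, i:i + chunk]         # read the block before it is overwritten
        z[:, i:i + chunk] = a @ (weight @ col)
    return z

z = torch.randn(6, 6)
w = torch.randn(6, 6)
expected = z.clone() @ (w @ z.clone())  # reference computed before z is modified
out = inplace_column_update_(z, w)
assert torch.allclose(out, expected, atol=1e-5)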
    def forward(
        self,
        z: torch.Tensor,
-       mask: Optional[torch.Tensor] = None
+       mask: Optional[torch.Tensor] = None,
+       inplace_safe: bool = False,
+       _add_with_inplace: bool = False,
+       _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
...
...
@@ -71,57 +371,52 @@ class TriangleMultiplicativeUpdate(nn.Module):
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
+       if (inplace_safe):
+           x = self._inference_forward(
+               z,
+               mask,
+               inplace_chunk_size=_inplace_chunk_size,
+               with_add=_add_with_inplace,
+           )
+           return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
-       a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
-       a = a * mask
-       b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
-       b = b * mask
-       x = self._combine_projections(a, b)
+       a = mask
+       a = a * self.sigmoid(self.linear_a_g(z))
+       a = a * self.linear_a_p(z)
+       b = mask
+       b = b * self.sigmoid(self.linear_b_g(z))
+       b = b * self.linear_b_p(z)
+
+       if (is_fp16_enabled()):
+           with torch.cuda.amp.autocast(enabled=False):
+               x = self._combine_projections(a.float(), b.float())
+       else:
+           x = self._combine_projections(a, b)
+
+       del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
-       z = x * g
+       x = x * g

-       return z
+       return x


 class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 11.
     """
-    def _combine_projections(
-        self,
-        a: torch.Tensor,  # [*, N_i, N_k, C]
-        b: torch.Tensor,  # [*, N_j, N_k, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 0, 1)),
-            permute_final_dims(b, (2, 1, 0)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
+    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)


 class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
     """
     Implements Algorithm 12.
     """
-    def _combine_projections(
-        self,
-        a: torch.Tensor,  # [*, N_k, N_i, C]
-        b: torch.Tensor,  # [*, N_k, N_j, C]
-    ):
-        # [*, C, N_i, N_j]
-        p = torch.matmul(
-            permute_final_dims(a, (2, 1, 0)),
-            permute_final_dims(b, (2, 0, 1)),
-        )
-
-        # [*, N_i, N_j, C]
-        return permute_final_dims(p, (1, 2, 0))
+    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
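After the merge, Algorithms 11 and 12 differ only in the `_outgoing` flag handed to the shared `_combine_projections`; the permute-then-matmul there is a batched contraction over the shared k dimension. A small standalone check of that equivalence for the outgoing case (the helper is restated here so the snippet runs on its own):

import torch

def permute_final_dims(t, inds):
    # Same behavior as the helper in openfold.utils.tensor_utils
    zero_index = -1 * len(inds)
    first_inds = list(range(len(t.shape[:zero_index])))
    return t.permute(first_inds + [zero_index + i for i in inds])

a = torch.randn(2, 5, 7, 3)  # [*, N_i, N_k, C]
b = torch.randn(2, 5, 7, 3)  # [*, N_j, N_k, C]

# "Outgoing" combination: contract over k
p = torch.matmul(
    permute_final_dims(a, (2, 0, 1)),   # [*, C, N_i, N_k]
    permute_final_dims(b, (2, 1, 0)),   # [*, C, N_k, N_j]
)
out = permute_final_dims(p, (1, 2, 0))  # [*, N_i, N_j, C]

ref = torch.einsum("...ikc,...jkc->...ijc", a, b)
assert torch.allclose(out, ref, atol=1e-5)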
openfold/np/__init__.py View file @ 39a6d0e6
import os
import glob
import importlib as importlib

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
openfold/np/protein.py View file @ 39a6d0e6
...
...
@@ -16,8 +16,9 @@
 """Protein data type."""
 import dataclasses
 import io
-from typing import Any, Mapping, Optional
+from typing import Any, Sequence, Mapping, Optional
+import re
+import string

 from openfold.np import residue_constants
 from Bio.PDB import PDBParser
...
...
@@ -51,16 +52,25 @@ class Protein:
    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

-   # 0-indexed number corresponding to the chain in the protein that this
-   # residue belongs to
-   chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

+   # Chain indices for multi-chain predictions
+   chain_index: Optional[np.ndarray] = None

+   # Optional remark about the protein. Included as a comment in output PDB
+   # files
+   remark: Optional[str] = None

+   # Templates used to generate this protein (prediction-only)
+   parents: Optional[Sequence[str]] = None

+   # Chain corresponding to each parent
+   parents_chain_index: Optional[Sequence[int]] = None

    def __post_init__(self):
        if (len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
            raise ValueError(
...
...
@@ -104,7 +114,6 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
        if (chain_id is not None and chain.id != chain_id):
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
...
...
@@ -129,17 +138,32 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue

            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
-           residue_index.append(mask)
+           residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints
+   parents = None
+   parents_chain_index = None
+   if ("PARENT" in pdb_str):
+       parents = []
+       parents_chain_index = []
+       chain_id = 0
+       for l in pdb_str.split("\n"):
+           if ("PARENT" in l):
+               if (not "N/A" in l):
+                   parent_names = l.split()[1:]
+                   parents.extend(parent_names)
+                   parents_chain_index.extend([chain_id for _ in parent_names])
+
+               chain_id += 1

    unique_chain_ids = np.unique(chain_ids)
-   chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
+   chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
...
...
@@ -149,6 +173,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
+       parents=parents,
+       parents_chain_index=parents_chain_index,
    )
...
...
@@ -213,6 +239,78 @@ def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    )


def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
    pdb_headers = []

    remark = prot.remark
    if (remark is not None):
        pdb_headers.append(f"REMARK {remark}")

    parents = prot.parents
    parents_chain_index = prot.parents_chain_index
    if (parents_chain_index is not None):
        parents = [
            p for i, p in zip(parents_chain_index, parents) if i == chain_id
        ]

    if (parents is None or len(parents) == 0):
        parents = ["N/A"]

    pdb_headers.append(f"PARENT {' '.join(parents)}")

    return pdb_headers


def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
    """ Add pdb headers to an existing PDB string. Useful during multi-chain
        recycling
    """
    out_pdb_lines = []
    lines = pdb_str.split('\n')

    remark = prot.remark
    if (remark is not None):
        out_pdb_lines.append(f"REMARK {remark}")

    parents_per_chain = None
    if (prot.parents is not None and len(prot.parents) > 0):
        parents_per_chain = []
        if (prot.parents_chain_index is not None):
            cur_chain = prot.parents_chain_index[0]
            parent_dict = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            parents_per_chain.append(prot.parents)
    else:
        parents_per_chain = [["N/A"]]

    make_parent_line = lambda p: f"PARENT {' '.join(p)}"

    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))

    chain_counter = 0
    for i, l in enumerate(lines):
        if ("PARENT" not in l and "REMARK" not in l):
            out_pdb_lines.append(l)
        if ("TER" in l and not "END" in lines[i + 1]):
            chain_counter += 1
            if (not chain_counter >= len(parents_per_chain)):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return '\n'.join(out_pdb_lines)
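`get_pdb_headers` and `add_pdb_headers` emit REMARK and PARENT records so that template provenance survives multi-chain recycling. A short illustration of the header lines produced for a hypothetical two-parent prediction; the f-strings mirror the ones above, and all names and values here are made up:

remark = "model 1, recycle 2"           # hypothetical remark
parents = ["5xyz_A", "6abc_B"]           # hypothetical template names
parents_chain_index = [0, 0]

chain_id = 0
chain_parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id] or ["N/A"]

headers = [f"REMARK {remark}", f"PARENT {' '.join(chain_parents)}"]
print("\n".join(headers))
# REMARK model 1, recycle 2
# PARENT 5xyz_A 6abc_B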
def
to_pdb
(
prot
:
Protein
)
->
str
:
"""Converts a `Protein` instance to a PDB string.
...
...
@@ -232,8 +330,8 @@ def to_pdb(prot: Protein) -> str:
aatype
=
prot
.
aatype
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
"Invalid aatypes."
)
...
...
@@ -247,9 +345,17 @@ def to_pdb(prot: Protein) -> str:
)
chain_ids
[
i
]
=
PDB_CHAIN_IDS
[
i
]
headers
=
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
pdb_lines
.
extend
(
headers
)
pdb_lines
.
append
(
"MODEL 1"
)
n
=
aatype
.
shape
[
0
]
atom_index
=
1
last_chain_index
=
chain_index
[
0
]
prev_chain_index
=
0
chain_tags
=
string
.
ascii_uppercase
# Add all atom sites.
for
i
in
range
(
aatype
.
shape
[
0
]):
# Close the previous chain if in a multichain PDB.
...
...
@@ -281,10 +387,17 @@ def to_pdb(prot: Protein) -> str:
0
]
# Protein supports only C, N, O, S, this works.
charge
=
""
chain_tag
=
"A"
if
(
chain_index
is
not
None
):
chain_tag
=
chain_tags
[
chain_index
[
i
]]
# PDB is a columnar format, every space matters here!
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_ids
[
chain_index
[
i
]]:
>
1
}
"
#TODO: check this refactor, chose main branch version
#f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f
"
{
res_name_3
:
>
3
}
{
chain_tag
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
...
...
@@ -293,16 +406,28 @@ def to_pdb(prot: Protein) -> str:
pdb_lines
.
append
(
atom_line
)
atom_index
+=
1
# Close the final chain.
pdb_lines
.
append
(
_chain_end
(
atom_index
,
res_1to3
(
aatype
[
-
1
]),
chain_ids
[
chain_index
[
-
1
]],
residue_index
[
-
1
]
)
)
should_terminate
=
(
i
==
n
-
1
)
if
(
chain_index
is
not
None
):
if
(
i
!=
n
-
1
and
chain_index
[
i
+
1
]
!=
prev_chain_index
):
should_terminate
=
True
prev_chain_index
=
chain_index
[
i
+
1
]
if
(
should_terminate
):
# Close the chain.
chain_end
=
"TER"
chain_termination_line
=
(
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
"
f
"
{
res_1to3
(
aatype
[
i
]):
>
3
}
"
f
"
{
chain_tag
:
>
1
}{
residue_index
[
i
]:
>
4
}
"
)
pdb_lines
.
append
(
chain_termination_line
)
atom_index
+=
1
if
(
i
!=
n
-
1
):
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines
.
extend
(
get_pdb_headers
(
prot
,
prev_chain_index
))
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"END"
)
...
...
@@ -332,6 +457,9 @@ def from_prediction(
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
remove_leading_feature_dimension
:
bool
=
True
,
remark
:
Optional
[
str
]
=
None
,
parents
:
Optional
[
Sequence
[
str
]]
=
None
,
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
)
->
Protein
:
"""Assembles a protein from a prediction.
...
...
@@ -341,7 +469,9 @@ def from_prediction(
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
A protein instance.
"""
...
...
@@ -349,7 +479,7 @@ def from_prediction(
        return arr[0] if remove_leading_feature_dimension else arr

    if 'asym_id' in features:
        chain_index = _maybe_remove_leading_dim(features["asym_id"])
        chain_index = _maybe_remove_leading_dim(features["asym_id"]) - 1
    else:
        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"])
...
...
@@ -363,6 +493,9 @@ def from_prediction(
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
        chain_index=chain_index,
        b_factors=b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
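Taken together, a typical caller of this module builds a `Protein` from the model output and serializes it with `to_pdb`; the REMARK/PARENT metadata passed here is what `get_pdb_headers`/`add_pdb_headers` later recover. A hypothetical usage sketch (the variable names `features`, `result` and `plddt_b_factors` are assumptions, not from this diff):

from openfold.np import protein

# Hypothetical usage: `features` is the processed feature dict and `result`
# the model output dict; `plddt_b_factors` is an assumed per-atom array.
prot = protein.from_prediction(
    features=features,
    result=result,
    b_factors=plddt_b_factors,
    remark="produced by OpenFold",   # stored as a REMARK header line
    parents=["1abc_A"],              # template names, emitted as PARENT lines
    parents_chain_index=[0],
)
with open("prediction.pdb", "w") as fp:
    fp.write(protein.to_pdb(prot))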
openfold/np/relax/__init__.py
View file @
39a6d0e6
import os
import glob
import importlib as importlib

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
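The deleted `__init__.py` above eagerly imported every sibling module with glob + importlib and bound each one into the package namespace. For this particular package (whose modules are visible in this commit), the loop is roughly equivalent to the explicit sketch below:

# Roughly what the glob/importlib loop expands to for openfold.np.relax
# (module list inferred from the files shown in this commit).
from . import amber_minimize, cleanup, relax, utils

__all__ = ["amber_minimize", "cleanup", "relax", "utils"]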
openfold/np/relax/amber_minimize.py
View file @
39a6d0e6
...
...
@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure

try:
    # openmm >= 7.6
    import openmm
    from openmm import unit
    from openmm import app as openmm_app
    from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
    # openmm < 7.6 (requires DeepMind patch)
    from simtk import openmm
    from simtk import unit
    from simtk.openmm import app as openmm_app
    from simtk.openmm.app.internal.pdbstructure import PdbStructure

ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
...
...
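The try/except block added above is a compatibility shim for the OpenMM 7.6 package rename: 7.6 and later install a top-level `openmm` package, while older builds only expose the `simtk` namespace (and, for this use case, need DeepMind's patch). The same guard works anywhere OpenMM is touched, e.g. this small sketch:

# Version-agnostic OpenMM import (same idea as the shim above).
try:
    import openmm                   # openmm >= 7.6
    from openmm import unit
except ImportError:
    from simtk import openmm, unit  # openmm < 7.6

one_kcal = 1 * unit.kilocalories_per_mole
print(openmm.__version__, one_kcal.value_in_unit(unit.kilojoules_per_mole))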
@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
    pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
    if checks:
        _check_cleaned_atoms(pdb_string, prot_pdb_string)

    headers = protein.get_pdb_headers(prot)
    if(len(headers) > 0):
        pdb_string = '\n'.join(['\n'.join(headers), pdb_string])

    return pdb_string
...
...
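The header re-attachment added to `clean_protein` exists because pdbfixer/OpenMM rewrite the PDB text and drop OpenFold's REMARK/PARENT lines; joining the recovered headers back on keeps that metadata across cleanup. A tiny sketch of the string operation (hypothetical values):

# Hypothetical: prepend recovered header lines to a cleaned PDB string.
headers = ["REMARK produced by OpenFold", "PARENT 1abc_A"]
cleaned = "ATOM      1  N   GLY A   1 ...\nTER\nEND"
cleaned = '\n'.join(['\n'.join(headers), cleaned])
print(cleaned.splitlines()[0])  # "REMARK produced by OpenFold"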
@@ -511,6 +524,9 @@ def run_pipeline(
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    # We keep the input around to restore metadata deleted by the relaxer
    input_prot = prot

    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    violations = np.inf
...
...
@@ -527,6 +543,11 @@ def run_pipeline(
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )

        headers = protein.get_pdb_headers(prot)
        if(len(headers) > 0):
            ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])

        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            pdb_string = clean_protein(prot, checks=True)
...
...
openfold/np/relax/cleanup.py
View file @
39a6d0e6
...
...
@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element

try:
    # openmm >= 7.6
    from openmm import app
    from openmm.app import element
except ImportError:
    # openmm < 7.6 (requires DeepMind patch)
    from simtk.openmm import app
    from simtk.openmm.app import element


def fix_pdb(pdbfile, alterations_info):
...
...
openfold/np/relax/relax.py
View file @
39a6d0e6
...
...
@@ -87,4 +87,7 @@ class AmberRelaxation(object):
        violations = out["structural_violations"]["total_per_residue_violations_mask"]

        min_pdb = protein.add_pdb_headers(prot, min_pdb)

        return min_pdb, debug_data, violations
openfold/np/relax/utils.py
View file @
39a6d0e6
...
...
@@ -18,8 +18,14 @@ import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure

try:
    # openmm >= 7.6
    from openmm import app as openmm_app
    from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
    # openmm < 7.6 (requires DeepMind patch)
    from simtk.openmm import app as openmm_app
    from simtk.openmm.app.internal.pdbstructure import PdbStructure


def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
...
...
openfold/np/residue_constants.py
View file @
39a6d0e6
...
...
@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
...
...
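The `dtype=np.int` to `dtype=int` edits in this file track NumPy's removal of the deprecated scalar aliases (`np.int` was deprecated in NumPy 1.20 and removed in 1.24); the Python builtin `int`, or an explicit `np.int32`/`np.int64`, is the supported spelling. For example:

import numpy as np

# np.zeros(..., dtype=np.int) raises AttributeError on NumPy >= 1.24;
# the builtin int (platform default integer width) behaves the same way.
atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
swap_idx = np.tile(np.arange(14, dtype=int), (21, 1))
print(atom14_to_rigid_group.dtype, swap_idx.shape)  # e.g. int64 (21, 14)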
@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=np.int), (21, 1)
    np.arange(14, dtype=int), (21, 1)
)
...
...
openfold/utils/__init__.py
View file @
39a6d0e6
import os
import glob
import importlib as importlib

from . import kernel

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
] + ["kernel"]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules